[management] add uniqueness constraint for peer ip and label and optimize generation (#4042)

This commit is contained in:
Pascal Fischer
2025-07-02 18:13:10 +02:00
committed by GitHub
parent 6c633497bc
commit 22678bce7f
13 changed files with 616 additions and 293 deletions

View File

@ -106,6 +106,18 @@ type DefaultAccountManager struct {
disableDefaultPolicy bool disableDefaultPolicy bool
} }
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
strings.Contains(err.Error(), "Error 1062 (23000)"),
strings.Contains(err.Error(), "UNIQUE constraint failed"):
return true
default:
return false
}
}
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred. // newly groups to create and an error if any occurred.
@ -1661,25 +1673,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
return false, nil return false, nil
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil { if err != nil {

View File

@ -2623,11 +2623,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account := &types.Account{ account := &types.Account{
Id: "accountID", Id: "accountID",
Peers: map[string]*nbpeer.Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, "peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
}, },
Groups: map[string]*types.Group{ Groups: map[string]*types.Group{
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
@ -3147,11 +3147,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 7, 20, 10, 80}, {"Small", 50, 5, 7, 20, 5, 80},
{"Medium", 500, 100, 5, 40, 30, 140}, {"Medium", 500, 100, 5, 40, 30, 140},
{"Large", 5000, 200, 80, 120, 140, 390}, {"Large", 5000, 200, 80, 120, 140, 390},
{"Small single", 50, 10, 7, 20, 10, 80}, {"Small single", 50, 10, 7, 20, 6, 80},
{"Medium single", 500, 10, 5, 40, 20, 85}, {"Medium single", 500, 10, 5, 40, 15, 85},
{"Large 5", 5000, 15, 80, 120, 80, 200}, {"Large 5", 5000, 15, 80, 120, 80, 200},
} }
@ -3343,11 +3343,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain) account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
require.NoError(t, err) require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId} peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
require.NoError(t, err) require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId} peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
require.NoError(t, err) require.NoError(t, err)

View File

@ -373,3 +373,42 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model) log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
return nil return nil
} }
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T
stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err)
}
tableName := stmt.Schema.Table
dialect := db.Dialector.Name()
var columnClause string
if dialect == "mysql" {
var withLength []string
for _, col := range columns {
if col == "ip" || col == "dns_label" {
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
} else {
withLength = append(withLength, col)
}
}
columnClause = strings.Join(withLength, ", ")
} else {
columnClause = strings.Join(columns, ", ")
}
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
if dialect == "postgres" || dialect == "sqlite" {
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
}
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
if err := db.Exec(createStmt).Error; err != nil {
return fmt.Errorf("failed to create index %s: %w", indexName, err)
}
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
return nil
}

View File

@ -15,13 +15,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
if peer.Name != update.Name { if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) var newLabel string
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get free DNS label: %w", err)
}
newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels)
if err != nil {
return err
} }
peer.Name = update.Name peer.Name = update.Name
@ -463,67 +460,50 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
upperKey := strings.ToUpper(setupKey) upperKey := strings.ToUpper(setupKey)
hashedKey := sha256.Sum256([]byte(upperKey)) hashedKey := sha256.Sum256([]byte(upperKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
var accountID string addedByUser := len(userID) > 0
var err error
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}
if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) _, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil { if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
} }
opEvent := &activity.Event{ opEvent := &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
AccountID: accountID,
} }
var newPeer *nbpeer.Peer var newPeer *nbpeer.Peer
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var setupKeyID string var setupKeyID string
var setupKeyName string var setupKeyName string
var ephemeral bool var ephemeral bool
var groupsToAdd []string var groupsToAdd []string
var allowExtraDNSLabels bool var allowExtraDNSLabels bool
var accountID string
if addedByUser { if addedByUser {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user groups: %w", err) return nil, nil, nil, fmt.Errorf("failed to get user groups: %w", err)
} }
groupsToAdd = user.AutoGroups groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser opEvent.Activity = activity.PeerAddedByUser
accountID = user.AccountID
} else { } else {
// Validate the setup key // Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to get setup key: %w", err) return nil, nil, nil, fmt.Errorf("failed to get setup key: %w", err)
} }
// we will check key twice for early return
if !sk.IsValid() { if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
} }
opEvent.InitiatorID = sk.Id opEvent.InitiatorID = sk.Id
@ -533,11 +513,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
setupKeyID = sk.Id setupKeyID = sk.Id
setupKeyName = sk.Name setupKeyName = sk.Name
allowExtraDNSLabels = sk.AllowExtraDNSLabels allowExtraDNSLabels = sk.AllowExtraDNSLabels
accountID = sk.AccountID
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
} }
} }
opEvent.AccountID = accountID
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil { if am.idpManager != nil {
@ -548,18 +530,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
return fmt.Errorf("failed to get free DNS label: %w", err)
}
freeIP, err := getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
}
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
} }
registrationTime := time.Now().UTC() registrationTime := time.Now().UTC()
@ -567,10 +539,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID, AccountID: accountID,
Key: peer.Key, Key: peer.Key,
IP: freeIP,
Meta: peer.Meta, Meta: peer.Meta,
Name: peer.Meta.Hostname, Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID, UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false, SSHEnabled: false,
@ -584,15 +554,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels, ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels,
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
} }
if am.geo != nil && newPeer.Location.ConnectionIP != nil { if am.geo != nil && newPeer.Location.ConnectionIP != nil {
@ -608,6 +572,41 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
}
maxAttempts := 10
for attempt := 1; attempt <= maxAttempts; attempt++ {
var freeIP net.IP
freeIP, err = types.AllocateRandomPeerIP(network.Net)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
}
var freeLabel string
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
}
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
if err != nil {
return err
}
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err) return fmt.Errorf("failed adding peer to All group: %w", err)
@ -622,9 +621,26 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil { if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
} else {
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
// we validate at the end to not block the setup key for too long
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
@ -632,39 +648,44 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil return nil
}) })
if err == nil {
unlock()
unlock = nil
break
}
if isUniqueConstraintError(err) {
unlock()
unlock = nil
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue
}
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
} }
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil { if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil") return nil, nil, nil, fmt.Errorf("new peer is nil")
} }
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
unlock() am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock = nil
if updateAccountPeers { if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID) am.BufferUpdateAccountPeers(ctx, accountID)
@ -673,23 +694,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
} }
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) ip = ip.To4()
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err) return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
} }
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) _, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err) //nolint:nilerr
return dnsName, nil
} }
nextIp, err := types.AllocatePeerIP(network.Net, takenIps) return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
@ -1477,19 +1496,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return groupIDs, err return groupIDs, err
} }
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
existingLabels := make(types.LookupMap)
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}
return existingLabels, nil
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration. // in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {

View File

@ -20,14 +20,14 @@ type Peer struct {
// WireGuard public key // WireGuard public key
Key string `gorm:"index"` Key string `gorm:"index"`
// IP address of the Peer // IP address of the Peer
IP net.IP `gorm:"serializer:json"` IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
// Meta is a Peer system meta data // Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name) // Name is peer's name (machine name)
Name string Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud // domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string DNSLabel string // uniqueness index per accountID (check migrations)
// Status peer's management connection status // Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer // The user ID that registered the peer

View File

@ -10,7 +10,9 @@ import (
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -19,6 +21,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@ -1391,7 +1394,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
name: "Absent setup key", name: "Absent setup key",
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB", existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
expectAddPeerError: true, expectAddPeerError: true,
expectedErrorMsgSubstring: "failed adding new peer: account not found", expectedErrorMsgSubstring: "failed to get setup key: setup key not found",
}, },
} }
@ -2057,10 +2060,14 @@ func Test_DeletePeer(t *testing.T) {
"peer1": { "peer1": {
ID: "peer1", ID: "peer1",
AccountID: accountID, AccountID: accountID,
IP: net.IP{1, 1, 1, 1},
DNSLabel: "peer1.test",
}, },
"peer2": { "peer2": {
ID: "peer2", ID: "peer2",
AccountID: accountID, AccountID: accountID,
IP: net.IP{2, 2, 2, 2},
DNSLabel: "peer2.test",
}, },
} }
account.Groups = map[string]*types.Group{ account.Groups = map[string]*types.Group{
@ -2090,3 +2097,138 @@ func Test_DeletePeer(t *testing.T) {
assert.NotContains(t, group.Peers, "peer1") assert.NotContains(t, group.Peers, "peer1")
} }
func Test_IsUniqueConstraintError(t *testing.T) {
tests := []struct {
name string
engine types.Engine
}{
{
name: "PostgreSQL uniqueness error",
engine: types.PostgresStoreEngine,
},
{
name: "MySQL uniqueness error",
engine: types.MysqlStoreEngine,
},
{
name: "SQLite uniqueness error",
engine: types.SqliteStoreEngine,
},
}
peer := &nbpeer.Peer{
ID: "test-peer-id",
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
DNSLabel: "test-peer-dns-label",
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
assert.NoError(t, err)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
result := isUniqueConstraintError(err)
assert.True(t, result)
})
}
}
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
if err != nil {
t.Fatal("error creating setup key")
return
}
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
const differentHostnames = 50
var wg sync.WaitGroup
errs := make(chan error, totalPeers+differentHostnames)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
hostNameID := i % differentHostnames
go func(i int) {
defer wg.Done()
newPeer := &nbpeer.Peer{
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
}
<-start
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
seenIP := make(map[string]bool)
for _, p := range account.Peers {
ipStr := p.IP.String()
if seenIP[ipStr] {
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
}
seenIP[ipStr] = true
}
seenLabel := make(map[string]bool)
for _, p := range account.Peers {
if seenLabel[p.DNSLabel] {
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
}
seenLabel[p.DNSLabel] = true
}
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
}

View File

@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update // if the All group didn't exist we probably don't have routes to update
continue continue
} }

View File

@ -92,8 +92,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
} }
if err := migrate(ctx, db); err != nil { if err := migratePreAuto(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err) return nil, fmt.Errorf("migratePreAuto: %w", err)
} }
err = db.AutoMigrate( err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
@ -102,7 +102,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err) return nil, fmt.Errorf("auto migratePreAuto: %w", err)
}
if err := migratePostAuto(ctx, db); err != nil {
return nil, fmt.Errorf("migratePostAuto: %w", err)
} }
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
@ -967,7 +970,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
return ips, nil return ips, nil
} }
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@ -975,7 +978,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
var labels []string var labels []string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
Pluck("dns_label", &labels) Pluck("dns_label", &labels)
if result.Error != nil { if result.Error != nil {
@ -1254,7 +1257,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(key) return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
} }
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
@ -2546,6 +2549,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
return &peer, nil return &peer, nil
} }
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peerID string
result := tx.Model(&nbpeer.Peer{}).
Select("id").
// Where(" = ?", hostname).
Where("account_id = ? AND dns_label = ?", accountID, hostname).
Limit(1).
Scan(&peerID)
if peerID == "" {
return "", gorm.ErrRecordNotFound
}
return peerID, result.Error
}
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) { func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
var count int64 var count int64
result := s.db.Model(&types.Account{}). result := s.db.Model(&types.Account{}).

View File

@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"sort"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -630,7 +631,7 @@ func TestMigrate(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
err = migrate(context.Background(), store.(*SqlStore).db) err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on empty db") require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24") _, ipnet, err := net.ParseCIDR("10.0.0.0/24")
@ -685,10 +686,10 @@ func TestMigrate(t *testing.T) {
err = store.(*SqlStore).db.Save(rt).Error err = store.(*SqlStore).db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data") require.NoError(t, err, "Failed to insert Gob data")
err = migrate(context.Background(), store.(*SqlStore).db) err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on gob populated db") require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(context.Background(), store.(*SqlStore).db) err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db") require.NoError(t, err, "Migration should not fail on migrated db")
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
@ -704,10 +705,10 @@ func TestMigrate(t *testing.T) {
err = store.(*SqlStore).db.Save(nRT).Error err = store.(*SqlStore).db.Save(nRT).Error
require.NoError(t, err, "Failed to insert json nil slice data") require.NoError(t, err, "Failed to insert json nil slice data")
err = migrate(context.Background(), store.(*SqlStore).db) err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on json nil slice populated db") require.NoError(t, err, "Migration should not fail on json nil slice populated db")
err = migrate(context.Background(), store.(*SqlStore).db) err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db") require.NoError(t, err, "Migration should not fail on migrated db")
} }
@ -950,6 +951,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
ID: "peer1", ID: "peer1",
AccountID: existingAccountID, AccountID: existingAccountID,
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1}, IP: net.IP{1, 1, 1, 1},
} }
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
@ -961,8 +963,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
assert.Equal(t, []net.IP{ip1}, takenIPs) assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{ peer2 := &nbpeer.Peer{
ID: "peer2", ID: "peer1second",
AccountID: existingAccountID, AccountID: existingAccountID,
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2}, IP: net.IP{2, 2, 2, 2},
} }
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
@ -972,26 +975,59 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16() ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
} }
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
}
t.Cleanup(cleanup)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peerHostname := "peer1"
_, err = store.GetAccount(context.Background(), existingAccountID) _, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{}, labels) assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err)
assert.Equal(t, []string{"peer1"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err)
expected := []string{"peer1", "peer1-1"}
sort.Strings(expected)
sort.Strings(labels)
assert.Equal(t, expected, labels)
})
}
func Test_AddPeerWithSameDnsLabel(t *testing.T) {
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
ID: "peer1", ID: "peer1",
AccountID: existingAccountID, AccountID: existingAccountID,
@ -1000,21 +1036,39 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{ peer2 := &nbpeer.Peer{
ID: "peer2", ID: "peer1second",
AccountID: existingAccountID, AccountID: existingAccountID,
DNSLabel: "peer2.domain.test", DNSLabel: "peer1.domain.test",
} }
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.Error(t, err)
})
}
func Test_AddPeerWithSameIP(t *testing.T) {
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.Error(t, err)
})
} }
func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) {

View File

@ -117,7 +117,7 @@ type Store interface {
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
@ -193,6 +193,7 @@ type Store interface {
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
} }
const ( const (
@ -234,9 +235,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) { if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
// Attempt to migrate from JSON store to SQLite // Attempt to migratePreAuto from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
kind = types.FileStoreEngine kind = types.FileStoreEngine
} }
} }
@ -280,9 +281,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error {
return nil return nil
} }
// migrate migrates the SQLite database to the latest schema // migratePreAuto migrates the SQLite database to the latest schema
func migrate(ctx context.Context, db *gorm.DB) error { func migratePreAuto(ctx context.Context, db *gorm.DB) error {
migrations := getMigrations(ctx) migrations := getMigrationsPreAuto(ctx)
for _, m := range migrations { for _, m := range migrations {
if err := m(db); err != nil { if err := m(db); err != nil {
@ -293,7 +294,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
return nil return nil
} }
func getMigrations(ctx context.Context) []migrationFunc { func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
return []migrationFunc{ return []migrationFunc{
func(db *gorm.DB) error { func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
@ -329,6 +330,28 @@ func getMigrations(ctx context.Context) []migrationFunc {
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
}, },
} }
} // migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
migrations := getMigrationsPostAuto(ctx)
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
},
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
},
}
} }
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. // NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
@ -577,7 +600,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
sqliteStoreAccounts := len(store.GetAllAccounts(ctx)) sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
if fsStoreAccounts != sqliteStoreAccounts { if fsStoreAccounts != sqliteStoreAccounts {
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
fsStoreAccounts, sqliteStoreAccounts) fsStoreAccounts, sqliteStoreAccounts)
} }

View File

@ -52,4 +52,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);

View File

@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@ -1,6 +1,7 @@
package types package types
import ( import (
"encoding/binary"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
@ -161,24 +162,65 @@ func (n *Network) Copy() *Network {
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps // This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
takenIPMap := make(map[string]struct{}) baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
takenIPMap[ipNet.IP.String()] = struct{}{} totalIPs := uint32(1 << SubnetSize)
taken := make(map[uint32]struct{}, len(takenIps)+1)
taken[baseIP] = struct{}{} // reserve network IP
taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP
for _, ip := range takenIps { for _, ip := range takenIps {
takenIPMap[ip.String()] = struct{}{} taken[ipToUint32(ip)] = struct{}{}
} }
ips, _ := generateIPs(&ipNet, takenIPMap) rng := rand.New(rand.NewSource(time.Now().UnixNano()))
maxAttempts := (int(totalIPs) - len(taken)) / 100
if len(ips) == 0 { for i := 0; i < maxAttempts; i++ {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) offset := uint32(rng.Intn(int(totalIPs-2))) + 1
candidate := baseIP + offset
if _, exists := taken[candidate]; !exists {
return uint32ToIP(candidate), nil
}
} }
// pick a random IP for offset := uint32(1); offset < totalIPs-1; offset++ {
s := rand.NewSource(time.Now().Unix()) candidate := baseIP + offset
r := rand.New(s) if _, exists := taken[candidate]; !exists {
intn := r.Intn(len(ips)) return uint32ToIP(candidate), nil
}
}
return ips[intn], nil return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String())
}
func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) {
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
totalIPs := uint32(1 << hostBits)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
candidate := baseIP + offset
return uint32ToIP(candidate), nil
}
func ipToUint32(ip net.IP) uint32 {
ip = ip.To4()
if len(ip) < 4 {
return 0
}
return binary.BigEndian.Uint32(ip)
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
} }
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list