diff --git a/management/server/account.go b/management/server/account.go index 3b7359502..8a80aefb6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -106,6 +106,18 @@ type DefaultAccountManager struct { disableDefaultPolicy bool } +func isUniqueConstraintError(err error) bool { + switch { + case strings.Contains(err.Error(), "(SQLSTATE 23505)"), + strings.Contains(err.Error(), "Error 1062 (23000)"), + strings.Contains(err.Error(), "UNIQUE constraint failed"): + return true + + default: + return false + } +} + // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. @@ -1661,25 +1673,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return "", fmt.Errorf("failed to get peer dns labels: %w", err) - } - - labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap) - if err != nil { - return "", fmt.Errorf("failed to get new host label: %w", err) - } - - if newLabel == "" { - return "", fmt.Errorf("failed to get new host label: %w", err) - } - - return newLabel, nil -} - func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f319b81e..60353389f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2623,11 +2623,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ - "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, - "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, - "peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, - "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, - "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, + "peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}, + "peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}, + "peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"}, + "peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"}, + "peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"}, }, Groups: map[string]*types.Group{ "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, @@ -3147,11 +3147,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 7, 20, 10, 80}, + {"Small", 50, 5, 7, 20, 5, 80}, {"Medium", 500, 100, 5, 40, 30, 140}, {"Large", 5000, 200, 80, 120, 140, 390}, - {"Small single", 50, 10, 7, 20, 10, 80}, - {"Medium single", 500, 10, 5, 40, 20, 85}, + {"Small single", 50, 10, 7, 20, 6, 80}, + {"Medium single", 500, 10, 5, 40, 15, 85}, {"Large 5", 5000, 15, 80, 120, 80, 200}, } @@ -3343,11 +3343,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain) require.NoError(t, err) - peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId} + peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) require.NoError(t, err) - peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId} + peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) require.NoError(t, err) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index c8a852e0a..ab11be731 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -373,3 +373,42 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model) return nil } + +func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { + var model T + + stmt := &gorm.Statement{DB: db} + if err := stmt.Parse(&model); err != nil { + return fmt.Errorf("failed to parse model schema: %w", err) + } + tableName := stmt.Schema.Table + dialect := db.Dialector.Name() + + var columnClause string + if dialect == "mysql" { + var withLength []string + for _, col := range columns { + if col == "ip" || col == "dns_label" { + withLength = append(withLength, fmt.Sprintf("%s(64)", col)) + } else { + withLength = append(withLength, col) + } + } + columnClause = strings.Join(withLength, ", ") + } else { + columnClause = strings.Join(columns, ", ") + } + + createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause) + if dialect == "postgres" || dialect == "sqlite" { + createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1) + } + + log.WithContext(ctx).Infof("executing index creation: %s", createStmt) + if err := db.Exec(createStmt).Error; err != nil { + return fmt.Errorf("failed to create index %s: %w", indexName, err) + } + + log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName) + return nil +} diff --git a/management/server/peer.go b/management/server/peer.go index 254048a96..2c1d8f64c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -15,13 +15,14 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peer.Name != update.Name { - existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) + var newLabel string + newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name) if err != nil { - return err - } - - newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels) - if err != nil { - return err + return fmt.Errorf("failed to get free DNS label: %w", err) } peer.Name = update.Name @@ -463,208 +460,232 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s upperKey := strings.ToUpper(setupKey) hashedKey := sha256.Sum256([]byte(upperKey)) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - var accountID string - var err error - addedByUser := false - if len(userID) > 0 { - addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) - } else { - accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) - } - if err != nil { - return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer func() { - if unlock != nil { - unlock() - } - }() + addedByUser := len(userID) > 0 // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) + _, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ Timestamp: time.Now().UTC(), - AccountID: accountID, } var newPeer *nbpeer.Peer var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - var setupKeyID string - var setupKeyName string - var ephemeral bool - var groupsToAdd []string - var allowExtraDNSLabels bool - if addedByUser { - user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) - if err != nil { - return fmt.Errorf("failed to get user groups: %w", err) - } - groupsToAdd = user.AutoGroups - opEvent.InitiatorID = userID - opEvent.Activity = activity.PeerAddedByUser - } else { - // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) - if err != nil { - return fmt.Errorf("failed to get setup key: %w", err) - } - - if !sk.IsValid() { - return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") - } - - opEvent.InitiatorID = sk.Id - opEvent.Activity = activity.PeerAddedWithSetupKey - groupsToAdd = sk.AutoGroups - ephemeral = sk.Ephemeral - setupKeyID = sk.Id - setupKeyName = sk.Name - allowExtraDNSLabels = sk.AllowExtraDNSLabels - - if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { - return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") - } - } - - if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { - if am.idpManager != nil { - userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) - if err == nil && userdata != nil { - peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) - } - } - } - - freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname) + var setupKeyID string + var setupKeyName string + var ephemeral bool + var groupsToAdd []string + var allowExtraDNSLabels bool + var accountID string + if addedByUser { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { - return fmt.Errorf("failed to get free DNS label: %w", err) + return nil, nil, nil, fmt.Errorf("failed to get user groups: %w", err) } - - freeIP, err := getFreeIP(ctx, transaction, accountID) + groupsToAdd = user.AutoGroups + opEvent.InitiatorID = userID + opEvent.Activity = activity.PeerAddedByUser + accountID = user.AccountID + } else { + // Validate the setup key + sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) if err != nil { - return fmt.Errorf("failed to get free IP: %w", err) + return nil, nil, nil, fmt.Errorf("failed to get setup key: %w", err) } - if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { - return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) + // we will check key twice for early return + if !sk.IsValid() { + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") } - registrationTime := time.Now().UTC() - newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: ®istrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, - InactivityExpirationEnabled: addedByUser, - ExtraDNSLabels: peer.ExtraDNSLabels, - AllowExtraDNSLabels: allowExtraDNSLabels, - } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } + opEvent.InitiatorID = sk.Id + opEvent.Activity = activity.PeerAddedWithSetupKey + groupsToAdd = sk.AutoGroups + ephemeral = sk.Ephemeral + setupKeyID = sk.Id + setupKeyName = sk.Name + allowExtraDNSLabels = sk.AllowExtraDNSLabels + accountID = sk.AccountID - opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) - if !addedByUser { - opEvent.Meta["setup_key_name"] = setupKeyName + if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } + } + opEvent.AccountID = accountID - if am.geo != nil && newPeer.Location.ConnectionIP != nil { - location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) - } else { - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { + if am.idpManager != nil { + userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) + if err == nil && userdata != nil { + peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) } } + } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) - - err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) - if err != nil { - return fmt.Errorf("failed adding peer to All group: %w", err) - } - - if len(groupsToAdd) > 0 { - for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) - if err != nil { - return err - } - } - } - - err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) - if err != nil { - return fmt.Errorf("failed to add peer to account: %w", err) - } - - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - if addedByUser { - err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) - if err != nil { - log.WithContext(ctx).Debugf("failed to update user last login: %v", err) - } - } else { - err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) - if err != nil { - return fmt.Errorf("failed to increment setup key usage: %w", err) - } - } - - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID) - if err != nil { - return err - } - - log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) - return nil - }) + if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { + return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) + } + registrationTime := time.Now().UTC() + newPeer = &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: ®istrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, + ExtraDNSLabels: peer.ExtraDNSLabels, + AllowExtraDNSLabels: allowExtraDNSLabels, + } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err) + } + + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed getting network: %w", err) + } + + maxAttempts := 10 + for attempt := 1; attempt <= maxAttempts; attempt++ { + var freeIP net.IP + freeIP, err = types.AllocateRandomPeerIP(network.Net) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err) + } + + var freeLabel string + freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } + + newPeer.DNSLabel = freeLabel + newPeer.IP = freeIP + + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlock != nil { + unlock() + } + }() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) + if err != nil { + return err + } + + err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + + if len(groupsToAdd) > 0 { + for _, g := range groupsToAdd { + err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) + if err != nil { + return err + } + } + } + + if addedByUser { + err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) + if err != nil { + log.WithContext(ctx).Debugf("failed to update user last login: %v", err) + } + } else { + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } + + // we validate at the end to not block the setup key for too long + if !sk.IsValid() { + return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + } + + err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) + if err != nil { + return fmt.Errorf("failed to increment setup key usage: %w", err) + } + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) + return nil + }) + if err == nil { + unlock() + unlock = nil + break + } + + if isUniqueConstraintError(err) { + unlock() + unlock = nil + log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) + continue + } + return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) + } + + updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) + if err != nil { + updateAccountPeers = true + } if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } - am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + opEvent.TargetID = newPeer.ID + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + if !addedByUser { + opEvent.Meta["setup_key_name"] = setupKeyName + } - unlock() - unlock = nil + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) if updateAccountPeers { am.BufferUpdateAccountPeers(ctx, accountID) @@ -673,23 +694,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { - takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) +func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) { + ip = ip.To4() + + dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) if err != nil { - return nil, fmt.Errorf("failed to get taken IPs: %w", err) + return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) + _, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName) if err != nil { - return nil, fmt.Errorf("failed getting network: %w", err) + //nolint:nilerr + return dnsName, nil } - nextIp, err := types.AllocatePeerIP(network.Net, takenIps) - if err != nil { - return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) - } - - return nextIp, nil + return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible @@ -1477,19 +1496,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { - dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - existingLabels := make(types.LookupMap) - for _, label := range dnsLabels { - existingLabels[label] = struct{}{} - } - return existingLabels, nil -} - // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 8ce1dfb4e..f7140e254 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -20,14 +20,14 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // IP address of the Peer - IP net.IP `gorm:"serializer:json"` + IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud - DNSLabel string + DNSLabel string // uniqueness index per accountID (check migrations) // Status peer's management connection status Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 775385a29..3edf7e82c 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -10,7 +10,9 @@ import ( "net/netip" "os" "runtime" + "strconv" "strings" + "sync" "testing" "time" @@ -19,6 +21,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -1391,7 +1394,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { name: "Absent setup key", existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB", expectAddPeerError: true, - expectedErrorMsgSubstring: "failed adding new peer: account not found", + expectedErrorMsgSubstring: "failed to get setup key: setup key not found", }, } @@ -2057,10 +2060,14 @@ func Test_DeletePeer(t *testing.T) { "peer1": { ID: "peer1", AccountID: accountID, + IP: net.IP{1, 1, 1, 1}, + DNSLabel: "peer1.test", }, "peer2": { ID: "peer2", AccountID: accountID, + IP: net.IP{2, 2, 2, 2}, + DNSLabel: "peer2.test", }, } account.Groups = map[string]*types.Group{ @@ -2090,3 +2097,138 @@ func Test_DeletePeer(t *testing.T) { assert.NotContains(t, group.Peers, "peer1") } + +func Test_IsUniqueConstraintError(t *testing.T) { + tests := []struct { + name string + engine types.Engine + }{ + { + name: "PostgreSQL uniqueness error", + engine: types.PostgresStoreEngine, + }, + { + name: "MySQL uniqueness error", + engine: types.MysqlStoreEngine, + }, + { + name: "SQLite uniqueness error", + engine: types.SqliteStoreEngine, + }, + } + + peer := &nbpeer.Peer{ + ID: "test-peer-id", + AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + DNSLabel: "test-peer-dns-label", + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine)) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + assert.NoError(t, err) + + err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + result := isUniqueConstraintError(err) + assert.True(t, result) + }) + } +} + +func Test_AddPeer(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false) + if err != nil { + t.Fatal("error creating setup key") + return + } + + const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries) + const differentHostnames = 50 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers+differentHostnames) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + hostNameID := i % differentHostnames + + go func(i int) { + defer wg.Done() + + newPeer := &nbpeer.Peer{ + Key: "key" + strconv.Itoa(i), + Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"}, + } + + <-start + + _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers)) + + seenIP := make(map[string]bool) + for _, p := range account.Peers { + ipStr := p.IP.String() + if seenIP[ipStr] { + t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr) + } + seenIP[ipStr] = true + } + + seenLabel := make(map[string]bool) + for _, p := range account.Peers { + if seenLabel[p.DNSLabel] { + t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel) + } + seenLabel[p.DNSLabel] = true + } + + assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) + assert.Equal(t, uint64(totalPeers), account.Network.Serial) +} diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index 3b95164f5..d5d9337ca 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { allGroup, err := account.GetGroupAll() if err != nil { - log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) + log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err) // if the All group didn't exist we probably don't have routes to update continue } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 72a73a57a..197255ab6 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -92,8 +92,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil } - if err := migrate(ctx, db); err != nil { - return nil, fmt.Errorf("migrate: %w", err) + if err := migratePreAuto(ctx, db); err != nil { + return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, @@ -102,7 +102,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, ) if err != nil { - return nil, fmt.Errorf("auto migrate: %w", err) + return nil, fmt.Errorf("auto migratePreAuto: %w", err) + } + if err := migratePostAuto(ctx, db); err != nil { + return nil, fmt.Errorf("migratePostAuto: %w", err) } return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil @@ -967,7 +970,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return ips, nil } -func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { +func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -975,7 +978,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock var labels []string result := tx.Model(&nbpeer.Peer{}). - Where("account_id = ?", accountID). + Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%"). Pluck("dns_label", &labels) if result.Error != nil { @@ -1254,7 +1257,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.NewSetupKeyNotFoundError(key) + return nil, status.Errorf(status.PreconditionFailed, "setup key not found") } log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") @@ -2546,6 +2549,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength return &peer, nil } +func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var peerID string + result := tx.Model(&nbpeer.Peer{}). + Select("id"). + // Where(" = ?", hostname). + Where("account_id = ? AND dns_label = ?", accountID, hostname). + Limit(1). + Scan(&peerID) + + if peerID == "" { + return "", gorm.ErrRecordNotFound + } + + return peerID, result.Error +} + func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) { var count int64 result := s.db.Model(&types.Account{}). diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index f187be8c7..928486ab4 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "runtime" + "sort" "sync" "testing" "time" @@ -630,7 +631,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -685,10 +686,10 @@ func TestMigrate(t *testing.T) { err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error @@ -704,10 +705,10 @@ func TestMigrate(t *testing.T) { err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -950,6 +951,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) @@ -961,8 +963,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) { assert.Equal(t, []net.IP{ip1}, takenIPs) peer2 := &nbpeer.Peer{ - ID: "peer2", + ID: "peer1second", AccountID: existingAccountID, + DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) @@ -972,49 +975,100 @@ func TestSqlite_GetTakenIPs(t *testing.T) { require.NoError(t, err) ip2 := net.IP{2, 2, 2, 2}.To16() assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) - } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) - if err != nil { - return - } - t.Cleanup(cleanup) + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerHostname := "peer1" - existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) - _, err = store.GetAccount(context.Background(), existingAccountID) - require.NoError(t, err) + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) - labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{}, labels) + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1", + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) - peer1 := &nbpeer.Peer{ - ID: "peer1", - AccountID: existingAccountID, - DNSLabel: "peer1.domain.test", - } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) - require.NoError(t, err) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) + assert.Equal(t, []string{"peer1"}, labels) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{"peer1.domain.test"}, labels) + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + DNSLabel: "peer1-1", + IP: net.IP{2, 2, 2, 2}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.NoError(t, err) - peer2 := &nbpeer.Peer{ - ID: "peer2", - AccountID: existingAccountID, - DNSLabel: "peer2.domain.test", - } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) - require.NoError(t, err) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) + expected := []string{"peer1", "peer1-1"} + sort.Strings(expected) + sort.Strings(labels) + assert.Equal(t, expected, labels) + }) +} + +func Test_AddPeerWithSameDnsLabel(t *testing.T) { + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) + + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.Error(t, err) + }) +} + +func Test_AddPeerWithSameIP(t *testing.T) { + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) + + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.Error(t, err) + }) } func TestSqlite_GetAccountNetwork(t *testing.T) { diff --git a/management/server/store/store.go b/management/server/store/store.go index f66130ad3..30ff1549d 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -117,7 +117,7 @@ type Store interface { SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) @@ -193,6 +193,7 @@ type Store interface { SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) + GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) } const ( @@ -234,9 +235,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) { log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) - // Attempt to migrate from JSON store to SQLite + // Attempt to migratePreAuto from JSON store to SQLite if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { - log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) + log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err) kind = types.FileStoreEngine } } @@ -280,9 +281,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error { return nil } -// migrate migrates the SQLite database to the latest schema -func migrate(ctx context.Context, db *gorm.DB) error { - migrations := getMigrations(ctx) +// migratePreAuto migrates the SQLite database to the latest schema +func migratePreAuto(ctx context.Context, db *gorm.DB) error { + migrations := getMigrationsPreAuto(ctx) for _, m := range migrations { if err := m(db); err != nil { @@ -293,7 +294,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { return nil } -func getMigrations(ctx context.Context) []migrationFunc { +func getMigrationsPreAuto(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") @@ -329,6 +330,28 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") }, } +} // migratePostAuto migrates the SQLite database to the latest schema +func migratePostAuto(ctx context.Context, db *gorm.DB) error { + migrations := getMigrationsPostAuto(ctx) + + for _, m := range migrations { + if err := m(db); err != nil { + return err + } + } + + return nil +} + +func getMigrationsPostAuto(ctx context.Context) []migrationFunc { + return []migrationFunc{ + func(db *gorm.DB) error { + return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip") + }, + func(db *gorm.DB) error { + return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label") + }, + } } // NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. @@ -577,7 +600,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { sqliteStoreAccounts := len(store.GetAllAccounts(ctx)) if fsStoreAccounts != sqliteStoreAccounts { - return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", + return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d", fsStoreAccounts, sqliteStoreAccounts) } diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 41b8fa2f7..4b126c618 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -52,4 +52,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); -INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); +INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 5990a0625..f2ef56a23 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62 INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/network.go b/management/server/types/network.go index 00082bb41..eb8415264 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -1,6 +1,7 @@ package types import ( + "encoding/binary" "math/rand" "net" "sync" @@ -161,24 +162,65 @@ func (n *Network) Copy() *Network { // This method considers already taken IPs and reuses IPs if there are gaps in takenIps // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { - takenIPMap := make(map[string]struct{}) - takenIPMap[ipNet.IP.String()] = struct{}{} + baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) + totalIPs := uint32(1 << SubnetSize) + + taken := make(map[uint32]struct{}, len(takenIps)+1) + taken[baseIP] = struct{}{} // reserve network IP + taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP + for _, ip := range takenIps { - takenIPMap[ip.String()] = struct{}{} + taken[ipToUint32(ip)] = struct{}{} } - ips, _ := generateIPs(&ipNet, takenIPMap) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + maxAttempts := (int(totalIPs) - len(taken)) / 100 - if len(ips) == 0 { - return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) + for i := 0; i < maxAttempts; i++ { + offset := uint32(rng.Intn(int(totalIPs-2))) + 1 + candidate := baseIP + offset + if _, exists := taken[candidate]; !exists { + return uint32ToIP(candidate), nil + } } - // pick a random IP - s := rand.NewSource(time.Now().Unix()) - r := rand.New(s) - intn := r.Intn(len(ips)) + for offset := uint32(1); offset < totalIPs-1; offset++ { + candidate := baseIP + offset + if _, exists := taken[candidate]; !exists { + return uint32ToIP(candidate), nil + } + } - return ips[intn], nil + return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String()) +} + +func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { + baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) + + ones, bits := ipNet.Mask.Size() + hostBits := bits - ones + + totalIPs := uint32(1 << hostBits) + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + offset := uint32(rng.Intn(int(totalIPs-2))) + 1 + + candidate := baseIP + offset + return uint32ToIP(candidate), nil +} + +func ipToUint32(ip net.IP) uint32 { + ip = ip.To4() + if len(ip) < 4 { + return 0 + } + return binary.BigEndian.Uint32(ip) +} + +func uint32ToIP(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip } // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list