mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 10:08:12 +02:00
[management] add uniqueness constraint for peer ip and label and optimize generation (#4042)
This commit is contained in:
@ -106,6 +106,18 @@ type DefaultAccountManager struct {
|
|||||||
disableDefaultPolicy bool
|
disableDefaultPolicy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUniqueConstraintError(err error) bool {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||||
|
strings.Contains(err.Error(), "Error 1062 (23000)"),
|
||||||
|
strings.Contains(err.Error(), "UNIQUE constraint failed"):
|
||||||
|
return true
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||||
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||||
// newly groups to create and an error if any occurred.
|
// newly groups to create and an error if any occurred.
|
||||||
@ -1661,25 +1673,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
|
|
||||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
labelMap := ConvertSliceToMap(existingLabels)
|
|
||||||
newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newLabel == "" {
|
|
||||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newLabel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2623,11 +2623,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
account := &types.Account{
|
account := &types.Account{
|
||||||
Id: "accountID",
|
Id: "accountID",
|
||||||
Peers: map[string]*nbpeer.Peer{
|
Peers: map[string]*nbpeer.Peer{
|
||||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
|
||||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
|
||||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
|
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
|
||||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
|
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
|
||||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
|
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
|
||||||
},
|
},
|
||||||
Groups: map[string]*types.Group{
|
Groups: map[string]*types.Group{
|
||||||
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
||||||
@ -3147,11 +3147,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
|||||||
minMsPerOpCICD float64
|
minMsPerOpCICD float64
|
||||||
maxMsPerOpCICD float64
|
maxMsPerOpCICD float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5, 7, 20, 10, 80},
|
{"Small", 50, 5, 7, 20, 5, 80},
|
||||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
{"Small single", 50, 10, 7, 20, 6, 80},
|
||||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
{"Medium single", 500, 10, 5, 40, 15, 85},
|
||||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3343,11 +3343,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
|
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
|
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -373,3 +373,42 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
|
|||||||
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
||||||
|
var model T
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
if err := stmt.Parse(&model); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse model schema: %w", err)
|
||||||
|
}
|
||||||
|
tableName := stmt.Schema.Table
|
||||||
|
dialect := db.Dialector.Name()
|
||||||
|
|
||||||
|
var columnClause string
|
||||||
|
if dialect == "mysql" {
|
||||||
|
var withLength []string
|
||||||
|
for _, col := range columns {
|
||||||
|
if col == "ip" || col == "dns_label" {
|
||||||
|
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
|
||||||
|
} else {
|
||||||
|
withLength = append(withLength, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
columnClause = strings.Join(withLength, ", ")
|
||||||
|
} else {
|
||||||
|
columnClause = strings.Join(columns, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
|
||||||
|
if dialect == "postgres" || dialect == "sqlite" {
|
||||||
|
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
|
||||||
|
if err := db.Exec(createStmt).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -15,13 +15,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peer.Name != update.Name {
|
if peer.Name != update.Name {
|
||||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
var newLabel string
|
||||||
|
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
@ -463,67 +460,50 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
upperKey := strings.ToUpper(setupKey)
|
upperKey := strings.ToUpper(setupKey)
|
||||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
var accountID string
|
addedByUser := len(userID) > 0
|
||||||
var err error
|
|
||||||
addedByUser := false
|
|
||||||
if len(userID) > 0 {
|
|
||||||
addedByUser = true
|
|
||||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
|
||||||
} else {
|
|
||||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer func() {
|
|
||||||
if unlock != nil {
|
|
||||||
unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||||
// and the peer disconnects with a timeout and tries to register again.
|
// and the peer disconnects with a timeout and tries to register again.
|
||||||
// We just check if this machine has been registered before and reject the second registration.
|
// We just check if this machine has been registered before and reject the second registration.
|
||||||
// The connecting peer should be able to recover with a retry.
|
// The connecting peer should be able to recover with a retry.
|
||||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
|
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent := &activity.Event{
|
opEvent := &activity.Event{
|
||||||
Timestamp: time.Now().UTC(),
|
Timestamp: time.Now().UTC(),
|
||||||
AccountID: accountID,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
var setupKeyID string
|
var setupKeyID string
|
||||||
var setupKeyName string
|
var setupKeyName string
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
var groupsToAdd []string
|
var groupsToAdd []string
|
||||||
var allowExtraDNSLabels bool
|
var allowExtraDNSLabels bool
|
||||||
|
var accountID string
|
||||||
if addedByUser {
|
if addedByUser {
|
||||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user groups: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get user groups: %w", err)
|
||||||
}
|
}
|
||||||
groupsToAdd = user.AutoGroups
|
groupsToAdd = user.AutoGroups
|
||||||
opEvent.InitiatorID = userID
|
opEvent.InitiatorID = userID
|
||||||
opEvent.Activity = activity.PeerAddedByUser
|
opEvent.Activity = activity.PeerAddedByUser
|
||||||
|
accountID = user.AccountID
|
||||||
} else {
|
} else {
|
||||||
// Validate the setup key
|
// Validate the setup key
|
||||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get setup key: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get setup key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we will check key twice for early return
|
||||||
if !sk.IsValid() {
|
if !sk.IsValid() {
|
||||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent.InitiatorID = sk.Id
|
opEvent.InitiatorID = sk.Id
|
||||||
@ -533,11 +513,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
setupKeyID = sk.Id
|
setupKeyID = sk.Id
|
||||||
setupKeyName = sk.Name
|
setupKeyName = sk.Name
|
||||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||||
|
accountID = sk.AccountID
|
||||||
|
|
||||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
opEvent.AccountID = accountID
|
||||||
|
|
||||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||||
if am.idpManager != nil {
|
if am.idpManager != nil {
|
||||||
@ -548,18 +530,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get free IP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
registrationTime := time.Now().UTC()
|
||||||
@ -567,10 +539,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Key: peer.Key,
|
Key: peer.Key,
|
||||||
IP: freeIP,
|
|
||||||
Meta: peer.Meta,
|
Meta: peer.Meta,
|
||||||
Name: peer.Meta.Hostname,
|
Name: peer.Meta.Hostname,
|
||||||
DNSLabel: freeLabel,
|
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
@ -584,15 +554,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||||
}
|
}
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get account settings: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
opEvent.TargetID = newPeer.ID
|
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
|
||||||
if !addedByUser {
|
|
||||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||||
@ -608,6 +572,41 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
|
|
||||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||||
|
|
||||||
|
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 10
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
var freeIP net.IP
|
||||||
|
freeIP, err = types.AllocateRandomPeerIP(network.Net)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var freeLabel string
|
||||||
|
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeer.DNSLabel = freeLabel
|
||||||
|
newPeer.IP = freeIP
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer func() {
|
||||||
|
if unlock != nil {
|
||||||
|
unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||||
@ -622,9 +621,26 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
if addedByUser {
|
||||||
|
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get setup key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// we validate at the end to not block the setup key for too long
|
||||||
|
if !sk.IsValid() {
|
||||||
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||||
@ -632,39 +648,44 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addedByUser {
|
|
||||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
if err == nil {
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUniqueConstraintError(err) {
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
|
||||||
if newPeer == nil {
|
if newPeer == nil {
|
||||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
opEvent.TargetID = newPeer.ID
|
||||||
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||||
|
if !addedByUser {
|
||||||
|
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||||
|
}
|
||||||
|
|
||||||
unlock()
|
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
unlock = nil
|
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||||
@ -673,23 +694,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
|
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
|
||||||
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
|
ip = ip.To4()
|
||||||
|
|
||||||
|
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
|
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
//nolint:nilerr
|
||||||
|
return dnsName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
|
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nextIp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
@ -1477,19 +1496,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
|||||||
return groupIDs, err
|
return groupIDs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
|
|
||||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
existingLabels := make(types.LookupMap)
|
|
||||||
for _, label := range dnsLabels {
|
|
||||||
existingLabels[label] = struct{}{}
|
|
||||||
}
|
|
||||||
return existingLabels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
// in an active DNS, route, or ACL configuration.
|
// in an active DNS, route, or ACL configuration.
|
||||||
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
||||||
|
@ -20,14 +20,14 @@ type Peer struct {
|
|||||||
// WireGuard public key
|
// WireGuard public key
|
||||||
Key string `gorm:"index"`
|
Key string `gorm:"index"`
|
||||||
// IP address of the Peer
|
// IP address of the Peer
|
||||||
IP net.IP `gorm:"serializer:json"`
|
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||||
// Meta is a Peer system meta data
|
// Meta is a Peer system meta data
|
||||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
// Name is peer's name (machine name)
|
// Name is peer's name (machine name)
|
||||||
Name string
|
Name string
|
||||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
DNSLabel string
|
DNSLabel string // uniqueness index per accountID (check migrations)
|
||||||
// Status peer's management connection status
|
// Status peer's management connection status
|
||||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||||
// The user ID that registered the peer
|
// The user ID that registered the peer
|
||||||
|
@ -10,7 +10,9 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -19,6 +21,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
@ -1391,7 +1394,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
name: "Absent setup key",
|
name: "Absent setup key",
|
||||||
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
||||||
expectAddPeerError: true,
|
expectAddPeerError: true,
|
||||||
expectedErrorMsgSubstring: "failed adding new peer: account not found",
|
expectedErrorMsgSubstring: "failed to get setup key: setup key not found",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2057,10 +2060,14 @@ func Test_DeletePeer(t *testing.T) {
|
|||||||
"peer1": {
|
"peer1": {
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
DNSLabel: "peer1.test",
|
||||||
},
|
},
|
||||||
"peer2": {
|
"peer2": {
|
||||||
ID: "peer2",
|
ID: "peer2",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
DNSLabel: "peer2.test",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
account.Groups = map[string]*types.Group{
|
account.Groups = map[string]*types.Group{
|
||||||
@ -2090,3 +2097,138 @@ func Test_DeletePeer(t *testing.T) {
|
|||||||
assert.NotContains(t, group.Peers, "peer1")
|
assert.NotContains(t, group.Peers, "peer1")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_IsUniqueConstraintError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
engine types.Engine
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "PostgreSQL uniqueness error",
|
||||||
|
engine: types.PostgresStoreEngine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL uniqueness error",
|
||||||
|
engine: types.MysqlStoreEngine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SQLite uniqueness error",
|
||||||
|
engine: types.SqliteStoreEngine,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: "test-peer-id",
|
||||||
|
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
DNSLabel: "test-peer-dns-label",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
|
||||||
|
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when creating store: %s", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||||
|
result := isUniqueConstraintError(err)
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeer(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := "testaccount"
|
||||||
|
userID := "testuser"
|
||||||
|
|
||||||
|
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
|
||||||
|
const differentHostnames = 50
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make(chan error, totalPeers+differentHostnames)
|
||||||
|
start := make(chan struct{})
|
||||||
|
for i := 0; i < totalPeers; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
hostNameID := i % differentHostnames
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
Key: "key" + strconv.Itoa(i),
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
|
||||||
|
}
|
||||||
|
|
||||||
|
<-start
|
||||||
|
|
||||||
|
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
|
||||||
|
if err != nil {
|
||||||
|
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
t.Logf("time since start: %s", time.Since(startTime))
|
||||||
|
|
||||||
|
for err := range errs {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||||
|
|
||||||
|
seenIP := make(map[string]bool)
|
||||||
|
for _, p := range account.Peers {
|
||||||
|
ipStr := p.IP.String()
|
||||||
|
if seenIP[ipStr] {
|
||||||
|
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
|
||||||
|
}
|
||||||
|
seenIP[ipStr] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
seenLabel := make(map[string]bool)
|
||||||
|
for _, p := range account.Peers {
|
||||||
|
if seenLabel[p.DNSLabel] {
|
||||||
|
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
|
||||||
|
}
|
||||||
|
seenLabel[p.DNSLabel] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
||||||
|
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||||
|
}
|
||||||
|
@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err)
|
||||||
// if the All group didn't exist we probably don't have routes to update
|
// if the All group didn't exist we probably don't have routes to update
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -92,8 +92,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrate(ctx, db); err != nil {
|
if err := migratePreAuto(ctx, db); err != nil {
|
||||||
return nil, fmt.Errorf("migrate: %w", err)
|
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(
|
err = db.AutoMigrate(
|
||||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||||
@ -102,7 +102,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||||
|
}
|
||||||
|
if err := migratePostAuto(ctx, db); err != nil {
|
||||||
|
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||||
@ -967,7 +970,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
|||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
@ -975,7 +978,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
|||||||
|
|
||||||
var labels []string
|
var labels []string
|
||||||
result := tx.Model(&nbpeer.Peer{}).
|
result := tx.Model(&nbpeer.Peer{}).
|
||||||
Where("account_id = ?", accountID).
|
Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
|
||||||
Pluck("dns_label", &labels)
|
Pluck("dns_label", &labels)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@ -1254,7 +1257,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewSetupKeyNotFoundError(key)
|
return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||||
@ -2546,6 +2549,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
|||||||
return &peer, nil
|
return &peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
||||||
|
tx := s.db.WithContext(ctx)
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var peerID string
|
||||||
|
result := tx.Model(&nbpeer.Peer{}).
|
||||||
|
Select("id").
|
||||||
|
// Where(" = ?", hostname).
|
||||||
|
Where("account_id = ? AND dns_label = ?", accountID, hostname).
|
||||||
|
Limit(1).
|
||||||
|
Scan(&peerID)
|
||||||
|
|
||||||
|
if peerID == "" {
|
||||||
|
return "", gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerID, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
result := s.db.Model(&types.Account{}).
|
result := s.db.Model(&types.Account{}).
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -630,7 +631,7 @@ func TestMigrate(t *testing.T) {
|
|||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on empty db")
|
require.NoError(t, err, "Migration should not fail on empty db")
|
||||||
|
|
||||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||||
@ -685,10 +686,10 @@ func TestMigrate(t *testing.T) {
|
|||||||
err = store.(*SqlStore).db.Save(rt).Error
|
err = store.(*SqlStore).db.Save(rt).Error
|
||||||
require.NoError(t, err, "Failed to insert Gob data")
|
require.NoError(t, err, "Failed to insert Gob data")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on gob populated db")
|
require.NoError(t, err, "Migration should not fail on gob populated db")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||||
|
|
||||||
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
||||||
@ -704,10 +705,10 @@ func TestMigrate(t *testing.T) {
|
|||||||
err = store.(*SqlStore).db.Save(nRT).Error
|
err = store.(*SqlStore).db.Save(nRT).Error
|
||||||
require.NoError(t, err, "Failed to insert json nil slice data")
|
require.NoError(t, err, "Failed to insert json nil slice data")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -950,6 +951,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
peer1 := &nbpeer.Peer{
|
peer1 := &nbpeer.Peer{
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1",
|
||||||
IP: net.IP{1, 1, 1, 1},
|
IP: net.IP{1, 1, 1, 1},
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
@ -961,8 +963,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{
|
peer2 := &nbpeer.Peer{
|
||||||
ID: "peer2",
|
ID: "peer1second",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1-1",
|
||||||
IP: net.IP{2, 2, 2, 2},
|
IP: net.IP{2, 2, 2, 2},
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
@ -972,26 +975,59 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ip2 := net.IP{2, 2, 2, 2}.To16()
|
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||||
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
peerHostname := "peer1"
|
||||||
|
|
||||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{}, labels)
|
assert.Equal(t, []string{}, labels)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1",
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1"}, labels)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer1second",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1-1",
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected := []string{"peer1", "peer1-1"}
|
||||||
|
sort.Strings(expected)
|
||||||
|
sort.Strings(labels)
|
||||||
|
assert.Equal(t, expected, labels)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeerWithSameDnsLabel(t *testing.T) {
|
||||||
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer1 := &nbpeer.Peer{
|
peer1 := &nbpeer.Peer{
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
@ -1000,21 +1036,39 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
|||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{
|
peer2 := &nbpeer.Peer{
|
||||||
ID: "peer2",
|
ID: "peer1second",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
DNSLabel: "peer2.domain.test",
|
DNSLabel: "peer1.domain.test",
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeerWithSameIP(t *testing.T) {
|
||||||
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer1second",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||||
|
@ -117,7 +117,7 @@ type Store interface {
|
|||||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||||
@ -193,6 +193,7 @@ type Store interface {
|
|||||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
||||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
||||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||||
|
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -234,9 +235,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type
|
|||||||
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
||||||
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
||||||
|
|
||||||
// Attempt to migrate from JSON store to SQLite
|
// Attempt to migratePreAuto from JSON store to SQLite
|
||||||
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
|
log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
|
||||||
kind = types.FileStoreEngine
|
kind = types.FileStoreEngine
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -280,9 +281,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrate migrates the SQLite database to the latest schema
|
// migratePreAuto migrates the SQLite database to the latest schema
|
||||||
func migrate(ctx context.Context, db *gorm.DB) error {
|
func migratePreAuto(ctx context.Context, db *gorm.DB) error {
|
||||||
migrations := getMigrations(ctx)
|
migrations := getMigrationsPreAuto(ctx)
|
||||||
|
|
||||||
for _, m := range migrations {
|
for _, m := range migrations {
|
||||||
if err := m(db); err != nil {
|
if err := m(db); err != nil {
|
||||||
@ -293,7 +294,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMigrations(ctx context.Context) []migrationFunc {
|
func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||||
return []migrationFunc{
|
return []migrationFunc{
|
||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
||||||
@ -329,6 +330,28 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
|||||||
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||||
|
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||||
|
migrations := getMigrationsPostAuto(ctx)
|
||||||
|
|
||||||
|
for _, m := range migrations {
|
||||||
|
if err := m(db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||||
|
return []migrationFunc{
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
|
||||||
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
||||||
@ -577,7 +600,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
|
|||||||
|
|
||||||
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
||||||
if fsStoreAccounts != sqliteStoreAccounts {
|
if fsStoreAccounts != sqliteStoreAccounts {
|
||||||
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d",
|
return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
|
||||||
fsStoreAccounts, sqliteStoreAccounts)
|
fsStoreAccounts, sqliteStoreAccounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@ -52,4 +52,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
|
|||||||
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
||||||
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
||||||
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
||||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||||
|
@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
|
|||||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -161,24 +162,65 @@ func (n *Network) Copy() *Network {
|
|||||||
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
||||||
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||||
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||||
takenIPMap := make(map[string]struct{})
|
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||||
takenIPMap[ipNet.IP.String()] = struct{}{}
|
totalIPs := uint32(1 << SubnetSize)
|
||||||
|
|
||||||
|
taken := make(map[uint32]struct{}, len(takenIps)+1)
|
||||||
|
taken[baseIP] = struct{}{} // reserve network IP
|
||||||
|
taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP
|
||||||
|
|
||||||
for _, ip := range takenIps {
|
for _, ip := range takenIps {
|
||||||
takenIPMap[ip.String()] = struct{}{}
|
taken[ipToUint32(ip)] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
maxAttempts := (int(totalIPs) - len(taken)) / 100
|
||||||
|
|
||||||
if len(ips) == 0 {
|
for i := 0; i < maxAttempts; i++ {
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||||
|
candidate := baseIP + offset
|
||||||
|
if _, exists := taken[candidate]; !exists {
|
||||||
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pick a random IP
|
for offset := uint32(1); offset < totalIPs-1; offset++ {
|
||||||
s := rand.NewSource(time.Now().Unix())
|
candidate := baseIP + offset
|
||||||
r := rand.New(s)
|
if _, exists := taken[candidate]; !exists {
|
||||||
intn := r.Intn(len(ips))
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ips[intn], nil
|
return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) {
|
||||||
|
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||||
|
|
||||||
|
ones, bits := ipNet.Mask.Size()
|
||||||
|
hostBits := bits - ones
|
||||||
|
|
||||||
|
totalIPs := uint32(1 << hostBits)
|
||||||
|
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||||
|
|
||||||
|
candidate := baseIP + offset
|
||||||
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipToUint32(ip net.IP) uint32 {
|
||||||
|
ip = ip.To4()
|
||||||
|
if len(ip) < 4 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return binary.BigEndian.Uint32(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint32ToIP(n uint32) net.IP {
|
||||||
|
ip := make(net.IP, 4)
|
||||||
|
binary.BigEndian.PutUint32(ip, n)
|
||||||
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
||||||
|
Reference in New Issue
Block a user