From 9e47c94a7f80680c64562133e84cc2ceb03256f5 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 30 Sep 2024 14:02:55 +0300 Subject: [PATCH] refactor setup keys Signed-off-by: bcmmbaga --- management/server/file_store.go | 8 +++ management/server/setupkey.go | 98 ++++++++++++++++++++------------- management/server/sql_store.go | 12 ++++ management/server/store.go | 2 + 4 files changed, 81 insertions(+), 39 deletions(-) diff --git a/management/server/file_store.go b/management/server/file_store.go index 37ad59291..2f76a40c7 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1051,6 +1051,14 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") } +func (s *FileStore) SaveSetupKey(_ context.Context, _ LockingStrength, _ *SetupKey) error { + return status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") +} + +func (s *FileStore) DeleteSetupKey(_ context.Context, _ LockingStrength, _ string) error { + return status.Errorf(status.Internal, "DeleteSetupKey is not implemented") +} + func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9521e22d3..41cf894ec 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -9,6 +9,7 @@ import ( "unicode/utf8" "github.com/google/uuid" + nbgroup "github.com/netbirdio/netbird/management/server/group" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" @@ -210,39 +211,49 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys") + } keyDuration := DefaultSetupKeyDuration if expiresIn != 0 { keyDuration = expiresIn } - account, err := am.Store.GetAccount(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, accountID) if err != nil { return nil, err } - if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { + if err = validateSetupKeyAutoGroups(groups, autoGroups); err != nil { return nil, err } setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) - account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed adding account key") + setupKey.AccountID = accountID + + if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey); err != nil { + return nil, err } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } for _, g := range setupKey.AutoGroups { - group := account.GetGroup(g) + group := groupMap[g] if group != nil { am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID) } } @@ -254,30 +265,30 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - var oldKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyToSave.Id { - oldKey = key.Copy() - break - } - } - if oldKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys") } - if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return nil, err + } + + if err = validateSetupKeyAutoGroups(groups, keyToSave.AutoGroups); err != nil { + return nil, err + } + + oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID) + if err != nil { return nil, err } @@ -288,9 +299,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() - account.SetupKeys[newKey.Key] = newKey - - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, newKey); err != nil { return nil, err } @@ -301,30 +310,34 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str defer func() { addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + for _, g := range removedGroups { - group := account.GetGroup(g) + group := groupMap[g] if group != nil { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID) } } for _, g := range addedGroups { - group := account.GetGroup(g) + group := groupMap[g] if group != nil { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID) } } }() - am.updateAccountPeers(ctx, account) - return newKey, nil } @@ -386,15 +399,22 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return setupKey, nil } -func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { - for _, group := range autoGroups { - g, ok := account.Groups[group] - if !ok { - return status.Errorf(status.NotFound, "group %s doesn't exist", group) +func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + + for _, groupID := range autoGroups { + g, exists := groupMap[groupID] + if !exists { + return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) } + if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } } + return nil } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 02d186804..6fa568aaa 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1180,6 +1180,18 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) } +// SaveSetupKey saves a setup key to the database. +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { + return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&setupKey).Error +} + +// DeleteSetupKey deletes a setup key from the database. +func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID string) error { + return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, idQueryCondition, setupKeyID).Error +} + // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) diff --git a/management/server/store.go b/management/server/store.go index 892bb15fe..26ca56474 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -93,6 +93,8 @@ type Store interface { IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)