From e3f3d2c1bdb8eb44d638581924df506188da7fa7 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 4 Oct 2024 15:17:28 +0200 Subject: [PATCH] update transaction logic --- management/server/account.go | 98 +++++++++++++++++------------ management/server/sql_store_test.go | 34 ++++++++++ management/server/user.go | 22 +++---- 3 files changed, 102 insertions(+), 52 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 8ad9bb536..bb3950629 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,6 +20,11 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -36,10 +41,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) const ( @@ -846,17 +847,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return false, nil, nil, err - } - - groups, err := am.Store.GetAccountGroups(ctx, accountID) - if err != nil { - return false, nil, nil, err - } - +func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range groups { existedGroupsByName[group.Name] = group @@ -880,7 +871,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID if !exists { group = &nbgroup.Group{ ID: xid.New().String(), - AccountID: accountID, + AccountID: user.AccountID, Name: name, Issued: nbgroup.GroupIssuedJWT, } @@ -1836,16 +1827,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames) - if err != nil { - return err - } - - // skip update if no changes - if !hasChanges { - return nil - } - unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { if unlockPeer != nil { @@ -1853,19 +1834,39 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } }() - if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { - return fmt.Errorf("error saving groups: %w", err) - } - - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId) - if err != nil { - return fmt.Errorf("error getting user: %w", err) - } - - addNewGroups := difference(updatedAutoGroups, user.AutoGroups) - removeOldGroups := difference(user.AutoGroups, updatedAutoGroups) - + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } + + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { + return nil + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) @@ -1873,7 +1874,22 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups) + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(ctx, groupsMap, peers, addNewGroups, removeOldGroups) if err != nil { return fmt.Errorf("error modifying user peers in groups: %w", err) } @@ -1895,6 +1911,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return err } + if !hasChanges { + return nil + } + for _, g := range addNewGroups { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) if err != nil { diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index dc07849d9..5cabe89aa 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1185,3 +1185,37 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjcetInTransaction(t *testing.T) { + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + + t.Logf("group: %v", group) + + return nil + }) + +} diff --git a/management/server/user.go b/management/server/user.go index 8c3ad846d..ed1c5453f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,6 +8,8 @@ import ( "time" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" @@ -15,7 +17,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) const ( @@ -1255,36 +1256,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd, +func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return } - peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) - if err != nil { - return nil, err - } - userPeerIDMap := make(map[string]struct{}, len(peers)) for _, peer := range peers { userPeerIDMap[peer.ID] = struct{}{} } for _, gid := range groupsToAdd { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return nil, err + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") } addUserPeersToGroup(userPeerIDMap, group) groupsToUpdate = append(groupsToUpdate, group) } for _, gid := range groupsToRemove { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return nil, err + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") } removeUserPeersFromGroup(userPeerIDMap, group) groupsToUpdate = append(groupsToUpdate, group)