From 37b082933cb9b60424dbdc9d4cd11bb723166812 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 2 Oct 2024 19:23:24 +0300 Subject: [PATCH] sync user jwt group changes Signed-off-by: bcmmbaga --- management/server/account.go | 173 +++++++++++++++++++------------- management/server/file_store.go | 10 +- management/server/sql_store.go | 32 ++++-- management/server/store.go | 4 +- 4 files changed, 139 insertions(+), 80 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index aee386211..59e9b74cd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -841,55 +841,64 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. +func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return false, nil, nil, err + } + + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return false, nil, nil, err } existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: accountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -1819,66 +1828,84 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames) if err != nil { + //log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err) return err } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) + if hasChanges { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() - } + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + // TODO propagate users groups changes to peers + //account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + //account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) + } + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + + if len(newGroupsToCreate) > 0 { + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + } + + for _, g := range addNewGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + } + } + + for _, g := range removeOldGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + } + } return nil - } - - // Propagate changes to peers if group propagation is enabled - if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } - - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } + }) + if err != nil { + return err } } @@ -2307,12 +2334,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2320,5 +2352,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/file_store.go b/management/server/file_store.go index 994a4b1ee..13205ca6e 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -964,7 +964,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } @@ -1042,3 +1042,11 @@ func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStren func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") } + +func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error { + return status.Errorf(status.Internal, "SaveUser is not implemented") +} + +func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error { + return status.Errorf(status.Internal, "SaveGroup is not implemented") +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 85c68ef44..b6db1193f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -378,15 +378,22 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } -// SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error + return nil +} + +// SaveGroups saves the given list of groups to the database. +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + } + return nil } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -1105,6 +1112,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the store. +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/store.go b/management/server/store.go index f34a73c2d..fe1f8c222 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error @@ -67,7 +68,8 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)