diff --git a/management/server/account.go b/management/server/account.go index 59e9b74cd..178aabd55 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1835,81 +1835,80 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames) if err != nil { - //log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err) return err } - // Update the account if group membership changes - if hasChanges { - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - - // Propagate changes to peers if group propagation is enabled - if settings.GroupsPropagationEnabled { - // TODO propagate users groups changes to peers - //account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - //account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return fmt.Errorf("error incrementing network serial: %w", err) - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } - - user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { - return fmt.Errorf("error saving user: %w", err) - } - - if len(newGroupsToCreate) > 0 { - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { - return fmt.Errorf("error saving groups: %w", err) - } - } - - for _, g := range addNewGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) - if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) - } else { - meta := map[string]any{ - "group": group.Name, "group_id": group.ID, - "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, - } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) - } - } - - for _, g := range removeOldGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) - if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) - } else { - meta := map[string]any{ - "group": group.Name, "group_id": group.ID, - "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, - } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) - } - } - return nil - }) - if err != nil { - return err - } + // skip update if no changes + if !hasChanges { + log.WithContext(ctx).Debugf("no changes in JWT group membership") + return nil } - return nil + return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) + + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil { + return fmt.Errorf("error adding user peers to groups: %w", err) + } + + if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil { + return fmt.Errorf("error removing user peers from groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) + } + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + for _, g := range addNewGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + } + } + + for _, g := range removeOldGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + } + } + return nil + }) } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. diff --git a/management/server/file_store.go b/management/server/file_store.go index 13205ca6e..b3375ee11 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1050,3 +1050,11 @@ func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) erro func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroup is not implemented") } + +func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error { + return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented") +} + +func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error { + return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index abf4cc185..fc75ea438 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -194,10 +194,10 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( codes.Unimplemented, diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b6db1193f..13f5c5e9e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "slices" "strings" "sync" "time" @@ -389,6 +390,10 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u // SaveGroups saves the given list of groups to the database. func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) if result.Error != nil { return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) @@ -1006,6 +1011,84 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddUserPeersToGroups adds the user's peers to specified groups in database. +func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for _, pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. +func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + if group.Name == "All" { + continue + } + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if !slices.Contains(userPeerIDs, pid) { + update = append(update, pid) + } + } + + group.Peers = update + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account") diff --git a/management/server/store.go b/management/server/store.go index fe1f8c222..9bc6eafce 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -81,6 +81,8 @@ type Store interface { GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error + RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error