sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-02 19:23:24 +03:00
parent 7d319d5a2e
commit 37b082933c
4 changed files with 139 additions and 80 deletions

View File

@@ -841,55 +841,64 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID]
}
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
user, ok := a.Users[userID]
if !ok {
return false
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred.
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return false, nil, nil, err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
}
existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups {
for _, group := range groups {
existedGroupsByName[group.Name] = group
}
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames)
newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false
return false, nil, nil, nil
}
newGroupsToCreate := make([]*nbgroup.Group, 0)
var modified bool
for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name]
if !exists {
group = &nbgroup.Group{
ID: xid.New().String(),
Name: name,
Issued: nbgroup.GroupIssuedJWT,
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Issued: nbgroup.GroupIssuedJWT,
}
a.Groups[group.ID] = group
newGroupsToCreate = append(newGroupsToCreate, group)
}
if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID)
newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true
}
}
for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id)
newUserAutoGroups = append(newUserAutoGroups, id)
continue
}
modified = true
}
user.AutoGroups = newAutoGroups
return modified
return modified, newUserAutoGroups, newGroupsToCreate, nil
}
// UserGroupsAddToPeers adds groups to all peers of user
@@ -1819,66 +1828,84 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil
}
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil {
//log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err)
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if hasChanges {
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
}
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
// TODO propagate users groups changes to peers
//account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
//account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
if len(newGroupsToCreate) > 0 {
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
}
for _, g := range addNewGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
return nil
}
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
})
if err != nil {
return err
}
}
@@ -2307,12 +2334,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) {
func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}
for _, id := range autoGroups {
if group, ok := allGroups[id]; ok {
if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id
} else {
@@ -2320,5 +2352,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
}
}
}
return newAutoGroups, jwtAutoGroups
}

View File

@@ -964,7 +964,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
@@ -1042,3 +1042,11 @@ func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStren
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
}
func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error {
return status.Errorf(status.Internal, "SaveUser is not implemented")
}
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroup is not implemented")
}

View File

@@ -378,15 +378,22 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -1105,6 +1112,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@@ -59,6 +59,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -67,7 +68,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)