sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-02 19:23:24 +03:00
parent 7d319d5a2e
commit 37b082933c
4 changed files with 139 additions and 80 deletions

View File

@@ -841,55 +841,64 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns true if there are changes in the JWT group membership. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { // newly groups to create and an error if any occurred.
user, ok := a.Users[userID] func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
if !ok { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
return false if err != nil {
return false, nil, nil, err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
} }
existedGroupsByName := make(map[string]*nbgroup.Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups { for _, group := range groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account // If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false return false, nil, nil, nil
} }
newGroupsToCreate := make([]*nbgroup.Group, 0)
var modified bool var modified bool
for _, name := range groupsToAdd { for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name] group, exists := existedGroupsByName[name]
if !exists { if !exists {
group = &nbgroup.Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: name, AccountID: accountID,
Issued: nbgroup.GroupIssuedJWT, Name: name,
Issued: nbgroup.GroupIssuedJWT,
} }
a.Groups[group.ID] = group newGroupsToCreate = append(newGroupsToCreate, group)
} }
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID) newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true modified = true
} }
} }
for name, id := range jwtGroupsMap { for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) { if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id) newUserAutoGroups = append(newUserAutoGroups, id)
continue continue
} }
modified = true modified = true
} }
user.AutoGroups = newAutoGroups
return modified return modified, newUserAutoGroups, newGroupsToCreate, nil
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
@@ -1819,66 +1828,84 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.JWTGroupsClaimName == "" { if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil return nil
} }
// TODO: Remove GetAccount after refactoring account peer's update jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
//log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err)
return err return err
} }
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes // Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { if hasChanges {
addNewGroups := difference(user.AutoGroups, oldGroups) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
removeOldGroups := difference(oldGroups, user.AutoGroups) oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled { addNewGroups := difference(user.AutoGroups, oldGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil { // Propagate changes to peers if group propagation is enabled
log.WithContext(ctx).Errorf("failed to save account: %v", err) if settings.GroupsPropagationEnabled {
// TODO propagate users groups changes to peers
//account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
//account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
if len(newGroupsToCreate) > 0 {
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
}
for _, g := range addNewGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
return nil return nil
} })
if err != nil {
// Propagate changes to peers if group propagation is enabled return err
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
} }
} }
@@ -2307,12 +2334,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups. // separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups, // Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs. // where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0) newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID jwtAutoGroups := make(map[string]string) // map of group name to group ID
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}
for _, id := range autoGroups { for _, id := range autoGroups {
if group, ok := allGroups[id]; ok { if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id jwtAutoGroups[group.Name] = id
} else { } else {
@@ -2320,5 +2352,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
} }
} }
} }
return newAutoGroups, jwtAutoGroups return newAutoGroups, jwtAutoGroups
} }

View File

@@ -964,7 +964,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented") return status.Errorf(status.Internal, "SaveUsers is not implemented")
} }
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented") return status.Errorf(status.Internal, "SaveGroups is not implemented")
} }
@@ -1042,3 +1042,11 @@ func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStren
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
} }
func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error {
return status.Errorf(status.Internal, "SaveUser is not implemented")
}
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroup is not implemented")
}

View File

@@ -378,15 +378,22 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error Create(&usersToSave).Error
} }
// SaveGroups saves the given list of groups to the database. // SaveUser saves the given user to the database.
// It updates existing groups if a conflict occurs. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
groupsToSave := make([]nbgroup.Group, 0, len(groups)) if result.Error != nil {
for _, group := range groups { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
} }
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
} }
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -1105,6 +1112,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@@ -59,6 +59,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -67,7 +68,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)