mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 02:50:43 +02:00
sync user jwt group changes
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -841,55 +841,64 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
|
|||||||
return a.Peers[peerID]
|
return a.Peers[peerID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
|
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||||
// Returns true if there are changes in the JWT group membership.
|
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||||
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
|
// newly groups to create and an error if any occurred.
|
||||||
user, ok := a.Users[userID]
|
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
||||||
if !ok {
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
return false
|
if err != nil {
|
||||||
|
return false, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
existedGroupsByName := make(map[string]*nbgroup.Group)
|
existedGroupsByName := make(map[string]*nbgroup.Group)
|
||||||
for _, group := range a.Groups {
|
for _, group := range groups {
|
||||||
existedGroupsByName[group.Name] = group
|
existedGroupsByName[group.Name] = group
|
||||||
}
|
}
|
||||||
|
|
||||||
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups)
|
newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
|
||||||
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
|
|
||||||
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames)
|
groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
|
||||||
|
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
|
||||||
|
|
||||||
// If no groups are added or removed, we should not sync account
|
// If no groups are added or removed, we should not sync account
|
||||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||||
return false
|
return false, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newGroupsToCreate := make([]*nbgroup.Group, 0)
|
||||||
|
|
||||||
var modified bool
|
var modified bool
|
||||||
for _, name := range groupsToAdd {
|
for _, name := range groupsToAdd {
|
||||||
group, exists := existedGroupsByName[name]
|
group, exists := existedGroupsByName[name]
|
||||||
if !exists {
|
if !exists {
|
||||||
group = &nbgroup.Group{
|
group = &nbgroup.Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: name,
|
AccountID: accountID,
|
||||||
Issued: nbgroup.GroupIssuedJWT,
|
Name: name,
|
||||||
|
Issued: nbgroup.GroupIssuedJWT,
|
||||||
}
|
}
|
||||||
a.Groups[group.ID] = group
|
newGroupsToCreate = append(newGroupsToCreate, group)
|
||||||
}
|
}
|
||||||
if group.Issued == nbgroup.GroupIssuedJWT {
|
if group.Issued == nbgroup.GroupIssuedJWT {
|
||||||
newAutoGroups = append(newAutoGroups, group.ID)
|
newUserAutoGroups = append(newUserAutoGroups, group.ID)
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, id := range jwtGroupsMap {
|
for name, id := range jwtGroupsMap {
|
||||||
if !slices.Contains(groupsToRemove, name) {
|
if !slices.Contains(groupsToRemove, name) {
|
||||||
newAutoGroups = append(newAutoGroups, id)
|
newUserAutoGroups = append(newUserAutoGroups, id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
user.AutoGroups = newAutoGroups
|
|
||||||
|
|
||||||
return modified
|
return modified, newUserAutoGroups, newGroupsToCreate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserGroupsAddToPeers adds groups to all peers of user
|
// UserGroupsAddToPeers adds groups to all peers of user
|
||||||
@@ -1819,66 +1828,84 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.JWTGroupsClaimName == "" {
|
if settings.JWTGroupsClaimName == "" {
|
||||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove GetAccount after refactoring account peer's update
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
//log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
|
||||||
|
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
|
||||||
copy(oldGroups, user.AutoGroups)
|
|
||||||
|
|
||||||
// Update the account if group membership changes
|
// Update the account if group membership changes
|
||||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
if hasChanges {
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
oldGroups := make([]string, len(user.AutoGroups))
|
||||||
|
copy(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled {
|
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
|
||||||
account.Network.IncSerial()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
// Propagate changes to peers if group propagation is enabled
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
if settings.GroupsPropagationEnabled {
|
||||||
|
// TODO propagate users groups changes to peers
|
||||||
|
//account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||||
|
//account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
user.AutoGroups = updatedAutoGroups
|
||||||
|
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
||||||
|
return fmt.Errorf("error saving user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newGroupsToCreate) > 0 {
|
||||||
|
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||||
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range addNewGroups {
|
||||||
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
|
} else {
|
||||||
|
meta := map[string]any{
|
||||||
|
"group": group.Name, "group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range removeOldGroups {
|
||||||
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
|
} else {
|
||||||
|
meta := map[string]any{
|
||||||
|
"group": group.Name, "group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
})
|
||||||
|
if err != nil {
|
||||||
// Propagate changes to peers if group propagation is enabled
|
return err
|
||||||
if settings.GroupsPropagationEnabled {
|
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2307,12 +2334,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
|||||||
// separateGroups separates user's auto groups into non-JWT and JWT groups.
|
// separateGroups separates user's auto groups into non-JWT and JWT groups.
|
||||||
// Returns the list of standard auto groups and a map of JWT auto groups,
|
// Returns the list of standard auto groups and a map of JWT auto groups,
|
||||||
// where the keys are the group names and the values are the group IDs.
|
// where the keys are the group names and the values are the group IDs.
|
||||||
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) {
|
func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
|
||||||
newAutoGroups := make([]string, 0)
|
newAutoGroups := make([]string, 0)
|
||||||
jwtAutoGroups := make(map[string]string) // map of group name to group ID
|
jwtAutoGroups := make(map[string]string) // map of group name to group ID
|
||||||
|
|
||||||
|
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
|
||||||
|
for _, group := range allGroups {
|
||||||
|
allGroupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range autoGroups {
|
for _, id := range autoGroups {
|
||||||
if group, ok := allGroups[id]; ok {
|
if group, ok := allGroupsMap[id]; ok {
|
||||||
if group.Issued == nbgroup.GroupIssuedJWT {
|
if group.Issued == nbgroup.GroupIssuedJWT {
|
||||||
jwtAutoGroups[group.Name] = id
|
jwtAutoGroups[group.Name] = id
|
||||||
} else {
|
} else {
|
||||||
@@ -2320,5 +2352,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newAutoGroups, jwtAutoGroups
|
return newAutoGroups, jwtAutoGroups
|
||||||
}
|
}
|
||||||
|
@@ -964,7 +964,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
|||||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
|
||||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1042,3 +1042,11 @@ func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStren
|
|||||||
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error {
|
||||||
|
return status.Errorf(status.Internal, "SaveUser is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
|
||||||
|
return status.Errorf(status.Internal, "SaveGroup is not implemented")
|
||||||
|
}
|
||||||
|
@@ -378,15 +378,22 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
|||||||
Create(&usersToSave).Error
|
Create(&usersToSave).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroups saves the given list of groups to the database.
|
// SaveUser saves the given user to the database.
|
||||||
// It updates existing groups if a conflict occurs.
|
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||||
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||||
groupsToSave := make([]nbgroup.Group, 0, len(groups))
|
if result.Error != nil {
|
||||||
for _, group := range groups {
|
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||||
group.AccountID = accountID
|
|
||||||
groupsToSave = append(groupsToSave, *group)
|
|
||||||
}
|
}
|
||||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveGroups saves the given list of groups to the database.
|
||||||
|
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||||
@@ -1105,6 +1112,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
return &group, nil
|
return &group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveGroup saves a group to the store.
|
||||||
|
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountPolicies retrieves policies for an account.
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||||
|
@@ -59,6 +59,7 @@ type Store interface {
|
|||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
@@ -67,7 +68,8 @@ type Store interface {
|
|||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||||
|
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
|
Reference in New Issue
Block a user