mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-10 15:48:29 +02:00
propagate jwt group changes to peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@ -1835,81 +1835,80 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
|
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the account if group membership changes
|
// skip update if no changes
|
||||||
if hasChanges {
|
if !hasChanges {
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
log.WithContext(ctx).Debugf("no changes in JWT group membership")
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
return nil
|
||||||
copy(oldGroups, user.AutoGroups)
|
|
||||||
|
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
|
||||||
|
|
||||||
// Propagate changes to peers if group propagation is enabled
|
|
||||||
if settings.GroupsPropagationEnabled {
|
|
||||||
// TODO propagate users groups changes to peers
|
|
||||||
//account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
|
||||||
//account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
|
||||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error getting account: %w", err)
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
user.AutoGroups = updatedAutoGroups
|
|
||||||
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
|
||||||
return fmt.Errorf("error saving user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(newGroupsToCreate) > 0 {
|
|
||||||
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
|
||||||
return fmt.Errorf("error saving groups: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
|
||||||
} else {
|
|
||||||
meta := map[string]any{
|
|
||||||
"group": group.Name, "group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
|
||||||
}
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
|
||||||
} else {
|
|
||||||
meta := map[string]any{
|
|
||||||
"group": group.Name, "group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
|
||||||
}
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
oldGroups := make([]string, len(user.AutoGroups))
|
||||||
|
copy(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||||
|
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
user.AutoGroups = updatedAutoGroups
|
||||||
|
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
||||||
|
return fmt.Errorf("error saving user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||||
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate changes to peers if group propagation is enabled
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {
|
||||||
|
return fmt.Errorf("error adding user peers to groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil {
|
||||||
|
return fmt.Errorf("error removing user peers from groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range addNewGroups {
|
||||||
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
|
} else {
|
||||||
|
meta := map[string]any{
|
||||||
|
"group": group.Name, "group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range removeOldGroups {
|
||||||
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
|
} else {
|
||||||
|
meta := map[string]any{
|
||||||
|
"group": group.Name, "group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||||
|
@ -1050,3 +1050,11 @@ func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) erro
|
|||||||
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
|
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
|
||||||
return status.Errorf(status.Internal, "SaveGroup is not implemented")
|
return status.Errorf(status.Internal, "SaveGroup is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error {
|
||||||
|
return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error {
|
||||||
|
return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented")
|
||||||
|
}
|
||||||
|
@ -27,7 +27,7 @@ type MockAccountManager struct {
|
|||||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
|
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@ -194,10 +194,10 @@ func (am *MockAccountManager) CreateSetupKey(
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
|
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
|
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
|
||||||
if am.GetAccountIDByUserOrAccountIdFunc != nil {
|
if am.GetAccountIDByUserIdFunc != nil {
|
||||||
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
|
||||||
}
|
}
|
||||||
return "", status.Errorf(
|
return "", status.Errorf(
|
||||||
codes.Unimplemented,
|
codes.Unimplemented,
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -389,6 +390,10 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
|
|||||||
|
|
||||||
// SaveGroups saves the given list of groups to the database.
|
// SaveGroups saves the given list of groups to the database.
|
||||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||||
@ -1006,6 +1011,84 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddUserPeersToGroups adds the user's peers to specified groups in database.
|
||||||
|
func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var userPeerIDs []string
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
|
||||||
|
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue finding user peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||||
|
for _, gid := range groupIDs {
|
||||||
|
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
groupPeers := make(map[string]struct{})
|
||||||
|
for _, pid := range group.Peers {
|
||||||
|
groupPeers[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pid := range userPeerIDs {
|
||||||
|
groupPeers[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = group.Peers[:0]
|
||||||
|
for pid := range groupPeers {
|
||||||
|
group.Peers = append(group.Peers, pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsToUpdate = append(groupsToUpdate, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveUserPeersFromGroups removes the user's peers from specified groups in database.
|
||||||
|
func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var userPeerIDs []string
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
|
||||||
|
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue finding user peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||||
|
for _, gid := range groupIDs {
|
||||||
|
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if group.Name == "All" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
update := make([]string, 0, len(group.Peers))
|
||||||
|
for _, pid := range group.Peers {
|
||||||
|
if !slices.Contains(userPeerIDs, pid) {
|
||||||
|
update = append(update, pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = update
|
||||||
|
groupsToUpdate = append(groupsToUpdate, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||||
return status.Errorf(status.Internal, "issue adding peer to account")
|
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||||
|
@ -81,6 +81,8 @@ type Store interface {
|
|||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
|
AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
|
||||||
|
RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
|
||||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
|
Reference in New Issue
Block a user