propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-02 21:05:59 +03:00
parent 37b082933c
commit fceaa2c03d
5 changed files with 167 additions and 75 deletions

View File

@ -1835,81 +1835,80 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames) hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil { if err != nil {
//log.WithContext(ctx).Debugf("skipping JWT groups sync for user %s: %v", claims.UserId, err)
return err return err
} }
// Update the account if group membership changes // skip update if no changes
if hasChanges { if !hasChanges {
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { log.WithContext(ctx).Debugf("no changes in JWT group membership")
oldGroups := make([]string, len(user.AutoGroups)) return nil
copy(oldGroups, user.AutoGroups)
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
// TODO propagate users groups changes to peers
//account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
//account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
if len(newGroupsToCreate) > 0 {
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
}
for _, g := range addNewGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
return nil
})
if err != nil {
return err
}
} }
return nil return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {
return fmt.Errorf("error adding user peers to groups: %w", err)
}
if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil {
return fmt.Errorf("error removing user peers from groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
return nil
})
} }
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.

View File

@ -1050,3 +1050,11 @@ func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) erro
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error { func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroup is not implemented") return status.Errorf(status.Internal, "SaveGroup is not implemented")
} }
func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error {
return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented")
}
func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error {
return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented")
}

View File

@ -27,7 +27,7 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -194,10 +194,10 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil { if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
} }
return "", status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,

View File

@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -389,6 +390,10 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
// SaveGroups saves the given list of groups to the database. // SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
@ -1006,6 +1011,84 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// AddUserPeersToGroups adds the user's peers to specified groups in database.
func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for _, pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = group.Peers[:0]
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// RemoveUserPeersFromGroups removes the user's peers from specified groups in database.
func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
if group.Name == "All" {
continue
}
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if !slices.Contains(userPeerIDs, pid) {
update = append(update, pid)
}
}
group.Peers = update
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account") return status.Errorf(status.Internal, "issue adding peer to account")

View File

@ -81,6 +81,8 @@ type Store interface {
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error