[management] Refactor User JWT group sync (#2690)

* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga 2024-10-04 17:17:01 +03:00 committed by GitHub
parent 158936fb15
commit 7f09b39769
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 520 additions and 177 deletions

View File

@ -76,7 +76,7 @@ type AccountManager interface {
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
@ -843,55 +843,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns true if there are changes in the JWT group membership. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { // newly groups to create and an error if any occurred.
user, ok := a.Users[userID] func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
if !ok {
return false
}
existedGroupsByName := make(map[string]*nbgroup.Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups { for _, group := range groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account // If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false return false, nil, nil, nil
} }
newGroupsToCreate := make([]*nbgroup.Group, 0)
var modified bool var modified bool
for _, name := range groupsToAdd { for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name] group, exists := existedGroupsByName[name]
if !exists { if !exists {
group = &nbgroup.Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: user.AccountID,
Name: name, Name: name,
Issued: nbgroup.GroupIssuedJWT, Issued: nbgroup.GroupIssuedJWT,
} }
a.Groups[group.ID] = group newGroupsToCreate = append(newGroupsToCreate, group)
} }
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID) newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true modified = true
} }
} }
for name, id := range jwtGroupsMap { for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) { if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id) newUserAutoGroups = append(newUserAutoGroups, id)
continue continue
} }
modified = true modified = true
} }
user.AutoGroups = newAutoGroups
return modified return modified, newUserAutoGroups, newGroupsToCreate, nil
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
@ -1262,24 +1261,18 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil return nil
} }
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
// If an accountID is provided, it checks if the account exists and returns it. // If user does have an account, it returns the user's account ID.
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// If the user doesn't have an account, it creates one using the provided domain. // If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created. // Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
if accountID != "" { if userID == "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) return "", status.Errorf(status.NotFound, "no valid userID provided")
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
} }
if userID != "" { accountID, err := am.Store.GetAccountIDByUserID(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil { if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
@ -1288,11 +1281,11 @@ func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Conte
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err return "", err
} }
return account.Id, nil return account.Id, nil
} }
return "", err
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") }
return accountID, nil
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@ -1796,6 +1789,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
}
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
@ -1803,7 +1800,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
} }
} }
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err return "", "", err
} }
@ -1812,7 +1809,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
@ -1823,67 +1820,134 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.JWTGroupsClaimName == "" { if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil return nil
} }
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups)) unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
copy(oldGroups, user.AutoGroups) defer func() {
if unlockPeer != nil {
unlockPeer()
}
}()
// Update the account if group membership changes var addNewGroups []string
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { var removeOldGroups []string
addNewGroups := difference(user.AutoGroups, oldGroups) var hasChanges bool
removeOldGroups := difference(oldGroups, user.AutoGroups) var user *User
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if settings.GroupsPropagationEnabled { user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) if err != nil {
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) return fmt.Errorf("error getting user: %w", err)
account.Network.IncSerial()
} }
if err := am.Store.SaveAccount(ctx, account); err != nil { groups, err := am.Store.GetAccountGroups(ctx, accountID)
log.WithContext(ctx).Errorf("failed to save account: %v", err) if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
hasChanges = changed
// skip update if no changes
if !changed {
return nil return nil
} }
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) groups, err = transaction.GetAccountGroups(ctx, accountID)
am.updateAccountPeers(ctx, account) if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
unlockPeer()
unlockPeer = nil
return nil
})
if err != nil {
return err
}
if !hasChanges {
return nil
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, if err != nil {
map[string]any{ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
"group": group.Name, } else {
"group_id": group.ID, meta := map[string]any{
"is_service_user": user.IsServiceUser, "group": group.Name, "group_id": group.ID,
"user_name": user.ServiceUserName}) "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
} }
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, if err != nil {
map[string]any{ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
"group": group.Name, } else {
"group_id": group.ID, meta := map[string]any{
"is_service_user": user.IsServiceUser, "group": group.Name, "group_id": group.ID,
"user_name": user.ServiceUserName}) "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
} }
} }
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
} }
return nil return nil
@ -1916,7 +1980,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil { if err != nil {
@ -2229,7 +2303,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
routes := make(map[route.ID]*route.Route) routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
owner := NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := DNSSettings{ dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0), DisabledManagementGroups: make([]string, 0),
} }
@ -2297,12 +2375,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups. // separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups, // Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs. // where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0) newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID jwtAutoGroups := make(map[string]string) // map of group name to group ID
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}
for _, id := range autoGroups { for _, id := range autoGroups {
if group, ok := allGroups[id]; ok { if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id jwtAutoGroups[group.Name] = id
} else { } else {
@ -2310,5 +2393,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
} }
} }
} }
return newAutoGroups, jwtAutoGroups return newAutoGroups, jwtAutoGroups
} }

View File

@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
initAccount := newAccountWithId(context.Background(), "", userId, domain) _ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err = manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
} }
} }
func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return return
} }
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
if err != nil { assert.NoError(t, err)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) assert.True(t, exists, "expected to get existing account after creation using userid")
}
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") _, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user ID is empty")
} }
} }
@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
} }
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account // create a new account
account := &Account{ account := &Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
}, },
Settings: &Settings{GroupsPropagationEnabled: true}, Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{ Users: map[string]*User{
"user1": {Id: "user1"}, "user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2"}, "user2": {Id: "user2", AccountID: "accountID"},
}, },
} }
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
t.Run("jwt match existing api group", func(t *testing.T) { t.Run("jwt match existing api group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("add jwt group", func(t *testing.T) { t.Run("add jwt group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("existed group not update", func(t *testing.T) { t.Run("existed group not update", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group2"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "groups count should not be changed") Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("add new group", func(t *testing.T) { t.Run("add new group", func(t *testing.T) {
updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user2",
assert.Len(t, account.Groups, 3, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") Raw: jwt.MapClaims{"groups": []interface{}{}},
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
}) })
} }

View File

@ -27,7 +27,7 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -194,14 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil { if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
} }
return "", status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountIDByUserOrAccountID is not implemented", "method GetAccountIDByUserID is not implemented",
) )
} }

View File

@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -378,15 +379,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error Create(&usersToSave).Error
} }
// SaveGroups saves the given list of groups to the database. // SaveUser saves the given user to the database.
// It updates existing groups if a conflict occurs. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
groupsToSave := make([]nbgroup.Group, 0, len(groups)) if result.Error != nil {
for _, group := range groups { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
} }
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
} }
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@ -1021,6 +1033,89 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// AddUserPeersToGroups adds the user's peers to specified groups in database.
func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for _, pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = group.Peers[:0]
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// RemoveUserPeersFromGroups removes the user's peers from specified groups in database.
func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
if group.Name == "All" {
continue
}
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if !slices.Contains(userPeerIDs, pid) {
update = append(update, pid)
}
}
group.Peers = update
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account") return status.Errorf(status.Internal, "issue adding peer to account")
@ -1127,6 +1222,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes) assert.Equal(t, 2, setupKey.UsedTimes)
} }
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
assert.NoError(t, err)
}

View File

@ -60,6 +60,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@ -68,7 +69,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@ -82,6 +84,7 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error

View File

@ -8,14 +8,14 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
} }
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
// skip removing peers from group All
if group.Name == "All" {
return
}
updatedPeers := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
}
}
group.Peers = updatedPeers
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData { for _, user := range userData {
if user.ID == userID { if user.ID == userID {

View File

@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") acc, err := am.Store.GetAccount(context.Background(), account.Id)
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {