Add batch save/update for groups and users (#2245)

* Add functionality to update multiple users

* Remove SaveUsers from DefaultAccountManager

* Add SaveGroups method to AccountManager interface

* Refactoring

* Add SaveUsers and SaveGroups methods to store interface

* Refactor method SaveAccount to SaveUsers and SaveGroups

The method SaveAccount in user.go and group.go files was split into two separate methods. Now, user-specific data is handled by SaveUsers and group-specific data is handled by SaveGroups method. This provides a cleaner and more efficient way to save user and group data.

* Add account ID to user and group in SqlStore

* Refactor SaveUsers and SaveGroups in store

* Remove unnecessary ID assignment in SaveUsers and SaveGroups
This commit is contained in:
Bethuel Mmbaga 2024-07-15 17:04:06 +03:00 committed by GitHub
parent 2577100096
commit 1537b0f5e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 318 additions and 153 deletions

View File

@ -69,6 +69,7 @@ type AccountManager interface {
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
@ -95,6 +96,7 @@ type AccountManager interface {
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error

View File

@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine { func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine return FileStoreEngine
} }
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

View File

@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { var eventsToStore []func()
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
existingGroup, err := account.FindGroupByName(newGroup.Name) if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
if err != nil { existingGroup, err := account.FindGroupByName(newGroup.Name)
s, ok := status.FromError(err) if err != nil {
if !ok || s.ErrorType != status.NotFound { s, ok := status.FromError(err)
return err if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are oldGroup := account.Groups[newGroup.ID]
// coming from the IdP that we don't have control of. account.Groups[newGroup.ID] = newGroup
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
} }
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
return err return err
} }
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
// the following snippet tracks the activity and stores the group events in the event store. for _, storeEvent := range eventsToStore {
// It has to happen after all the operations have been successfully performed. storeEvent()
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if exists {
if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
addedPeers = append(addedPeers, newGroup.Peers...) addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
} }
for _, p := range addedPeers { for _, p := range addedPeers {
@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, peerCopy := peer // copy to avoid closure issues
map[string]any{ eventsToStore = append(eventsToStore, func() {
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), map[string]any{
}) "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
} }
for _, p := range removedPeers { for _, p := range removedPeers {
@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, peerCopy := peer // copy to avoid closure issues
map[string]any{ eventsToStore = append(eventsToStore, func() {
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), map[string]any{
}) "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
} }
return nil return eventsToStore
} }
// difference returns the elements in `a` that aren't in `b`. // difference returns the elements in `a` that aren't in `b`.

View File

@ -40,6 +40,7 @@ type MockAccountManager struct {
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@ -64,6 +65,7 @@ type MockAccountManager struct {
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
} }
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil { if am.DeleteGroupFunc != nil {
@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
} }
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
if am.SaveOrAddUsersFunc != nil {
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
}
// DeleteUser mocks DeleteUser of the AccountManager interface // DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil { if am.DeleteUserFunc != nil {

View File

@ -311,6 +311,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
return nil return nil
} }
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil return nil

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@ -41,6 +42,8 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string GetInstallationID() string

View File

@ -740,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return pats, nil return pats, nil
} }
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
} }
@ -748,11 +748,31 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
if update == nil { updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") if err != nil {
return nil, err
}
if len(updatedUsers) == 0 {
return nil, status.Errorf(status.Internal, "user was not updated")
}
return updatedUsers[0], nil
}
// SaveOrAddUsers updates existing users or adds new users to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
if len(updates) == 0 {
return nil, nil //nolint:nilnil
} }
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -769,144 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
} }
oldUser := account.Users[update.Id] updatedUsers := make([]*UserInfo, 0, len(updates))
if oldUser == nil { var (
if !addIfNotExists { expiredPeers []*nbpeer.Peer
return nil, status.Errorf(status.NotFound, "user to update doesn't exist") eventsToStore []func()
)
for _, update := range updates {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
} }
// when addIfNotExists is set to true the newUser will use all fields from the update input
oldUser = update
}
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked { oldUser := account.Users[update.Id]
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") if oldUser == nil {
} if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role { }
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role") // when addIfNotExists is set to true, the newUser will use all fields from the update input
} oldUser = update
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
transferedOwnerRole := false
if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUserID] = newInitiatorUser
transferedOwnerRole = true
}
// only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
} }
}
newUser.AutoGroups = update.AutoGroups
account.Users[newUser.Id] = newUser if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err return nil, err
} }
if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil { // only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
newUser.AutoGroups = update.AutoGroups
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
expiredPeers = append(expiredPeers, blockedPeers...)
}
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil {
return nil, err
}
updatedUsers = append(updatedUsers, updatedUserInfo)
}
if len(expiredPeers) > 0 {
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed update expired peers: %s", err) log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
return nil, err return nil, err
} }
} }
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { account.Network.IncSerial()
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
// need force update all auto groups in any case they will not be duplicated return nil, err
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) }
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} else { }
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err for _, storeEvent := range eventsToStore {
storeEvent()
}
return updatedUsers, nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
})
} else {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
})
} }
} }
defer func() { switch {
if oldUser.IsBlocked() != update.IsBlocked() { case transferredOwnerRole:
if update.IsBlocked() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
})
case oldUser.Role != newUser.Role:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
})
}
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else { } else {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
for _, g := range addedGroups {
switch { group := account.GetGroup(g)
case transferedOwnerRole: if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) eventsToStore = append(eventsToStore, func() {
case oldUser.Role != newUser.Role: am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
default:
}
if update.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { })
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
}
} }
} }
}() }
if !isNil(am.idpManager) && !newUser.IsServiceUser { return eventsToStore
userData, err := am.lookupUserInCache(ctx, newUser.Id, account) }
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUser.Id] = newInitiatorUser
return true
}
return false
}
// getUserInfo retrieves the UserInfo for a given User and Account.
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newUser.ToUserInfo(userData, account.Settings) return user.ToUserInfo(userData, account.Settings)
} }
return newUser.ToUserInfo(nil, account.Settings) return user.ToUserInfo(nil, account.Settings)
}
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
return status.Errorf(status.PermissionDenied, "admins can't change their role")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
return nil
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist