mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 17:31:39 +02:00
Add batch save/update for groups and users (#2245)
* Add functionality to update multiple users * Remove SaveUsers from DefaultAccountManager * Add SaveGroups method to AccountManager interface * Refactoring * Add SaveUsers and SaveGroups methods to store interface * Refactor method SaveAccount to SaveUsers and SaveGroups The method SaveAccount in user.go and group.go files was split into two separate methods. Now, user-specific data is handled by SaveUsers and group-specific data is handled by SaveGroups method. This provides a cleaner and more efficient way to save user and group data. * Add account ID to user and group in SqlStore * Refactor SaveUsers and SaveGroups in store * Remove unnecessary ID assignment in SaveUsers and SaveGroups
This commit is contained in:
parent
2577100096
commit
1537b0f5e7
@ -69,6 +69,7 @@ type AccountManager interface {
|
|||||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
|
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
|
||||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||||
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
||||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||||
@ -95,6 +96,7 @@ type AccountManager interface {
|
|||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
|
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
|
||||||
|
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
|
@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error {
|
|||||||
func (s *FileStore) GetStoreEngine() StoreEngine {
|
func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||||
return FileStoreEngine
|
return FileStoreEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||||
|
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||||
|
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||||
|
}
|
||||||
|
@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
|||||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveGroups adds new groups to the account.
|
||||||
|
// Note: This function does not acquire the global lock.
|
||||||
|
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||||
|
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
var eventsToStore []func()
|
||||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
for _, newGroup := range newGroups {
|
||||||
|
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||||
|
}
|
||||||
|
|
||||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||||
if err != nil {
|
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||||
s, ok := status.FromError(err)
|
if err != nil {
|
||||||
if !ok || s.ErrorType != status.NotFound {
|
s, ok := status.FromError(err)
|
||||||
return err
|
if !ok || s.ErrorType != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid duplicate groups only for the API issued groups.
|
||||||
|
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||||
|
if existingGroup != nil {
|
||||||
|
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup.ID = xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range newGroup.Peers {
|
||||||
|
if account.Peers[peerID] == nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
|
oldGroup := account.Groups[newGroup.ID]
|
||||||
// coming from the IdP that we don't have control of.
|
account.Groups[newGroup.ID] = newGroup
|
||||||
if existingGroup != nil {
|
|
||||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
newGroup.ID = xid.New().String()
|
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||||
|
eventsToStore = append(eventsToStore, events...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range newGroup.Peers {
|
|
||||||
if account.Peers[peerID] == nil {
|
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldGroup, exists := account.Groups[newGroup.ID]
|
|
||||||
account.Groups[newGroup.ID] = newGroup
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
// the following snippet tracks the activity and stores the group events in the event store.
|
for _, storeEvent := range eventsToStore {
|
||||||
// It has to happen after all the operations have been successfully performed.
|
storeEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||||
|
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
addedPeers := make([]string, 0)
|
addedPeers := make([]string, 0)
|
||||||
removedPeers := make([]string, 0)
|
removedPeers := make([]string, 0)
|
||||||
if exists {
|
|
||||||
|
if oldGroup != nil {
|
||||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||||
} else {
|
} else {
|
||||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range addedPeers {
|
for _, p := range addedPeers {
|
||||||
@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
|||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
peerCopy := peer // copy to avoid closure issues
|
||||||
map[string]any{
|
eventsToStore = append(eventsToStore, func() {
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
map[string]any{
|
||||||
})
|
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||||
|
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range removedPeers {
|
for _, p := range removedPeers {
|
||||||
@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
|||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
peerCopy := peer // copy to avoid closure issues
|
||||||
map[string]any{
|
eventsToStore = append(eventsToStore, func() {
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
map[string]any{
|
||||||
})
|
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||||
|
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return eventsToStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// difference returns the elements in `a` that aren't in `b`.
|
// difference returns the elements in `a` that aren't in `b`.
|
||||||
|
@ -40,6 +40,7 @@ type MockAccountManager struct {
|
|||||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
|
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
|
||||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
|
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
|
||||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
||||||
|
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
@ -64,6 +65,7 @@ type MockAccountManager struct {
|
|||||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
|
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
|
||||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
|
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||||
|
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
|
||||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
|
|||||||
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
|
||||||
|
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
|
||||||
|
if am.SaveGroupsFunc != nil {
|
||||||
|
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||||
if am.DeleteGroupFunc != nil {
|
if am.DeleteGroupFunc != nil {
|
||||||
@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
|
||||||
|
if am.SaveOrAddUsersFunc != nil {
|
||||||
|
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||||
if am.DeleteUserFunc != nil {
|
if am.DeleteUserFunc != nil {
|
||||||
|
@ -311,6 +311,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveUsers saves the given list of users to the database.
|
||||||
|
// It updates existing users if a conflict occurs.
|
||||||
|
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||||
|
usersToSave := make([]User, 0, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
user.AccountID = accountID
|
||||||
|
for id, pat := range user.PATs {
|
||||||
|
pat.ID = id
|
||||||
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
|
}
|
||||||
|
usersToSave = append(usersToSave, *user)
|
||||||
|
}
|
||||||
|
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||||
|
Create(&usersToSave).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveGroups saves the given list of groups to the database.
|
||||||
|
// It updates existing groups if a conflict occurs.
|
||||||
|
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||||
|
groupsToSave := make([]nbgroup.Group, 0, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
group.AccountID = accountID
|
||||||
|
groupsToSave = append(groupsToSave, *group)
|
||||||
|
}
|
||||||
|
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||||
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@ -41,6 +42,8 @@ type Store interface {
|
|||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
|
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
|
@ -740,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
|||||||
return pats, nil
|
return pats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
||||||
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||||
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
|
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
|
||||||
}
|
}
|
||||||
@ -748,11 +748,31 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
|
|||||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||||
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
||||||
|
if update == nil {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if update == nil {
|
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(updatedUsers) == 0 {
|
||||||
|
return nil, status.Errorf(status.Internal, "user was not updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedUsers[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveOrAddUsers updates existing users or adds new users to the account.
|
||||||
|
// Note: This function does not acquire the global lock.
|
||||||
|
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||||
|
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
@ -769,144 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUser := account.Users[update.Id]
|
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||||
if oldUser == nil {
|
var (
|
||||||
if !addIfNotExists {
|
expiredPeers []*nbpeer.Peer
|
||||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
eventsToStore []func()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, update := range updates {
|
||||||
|
if update == nil {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
}
|
}
|
||||||
// when addIfNotExists is set to true the newUser will use all fields from the update input
|
|
||||||
oldUser = update
|
|
||||||
}
|
|
||||||
|
|
||||||
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
oldUser := account.Users[update.Id]
|
||||||
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
if oldUser == nil {
|
||||||
}
|
if !addIfNotExists {
|
||||||
|
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
|
||||||
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role {
|
}
|
||||||
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
// when addIfNotExists is set to true, the newUser will use all fields from the update input
|
||||||
}
|
oldUser = update
|
||||||
|
|
||||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
|
||||||
}
|
|
||||||
|
|
||||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "unable to block owner user")
|
|
||||||
}
|
|
||||||
|
|
||||||
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
|
|
||||||
}
|
|
||||||
|
|
||||||
transferedOwnerRole := false
|
|
||||||
if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner {
|
|
||||||
newInitiatorUser := initiatorUser.Copy()
|
|
||||||
newInitiatorUser.Role = UserRoleAdmin
|
|
||||||
account.Users[initiatorUserID] = newInitiatorUser
|
|
||||||
transferedOwnerRole = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// only auto groups, revoked status, and integration reference can be updated for now
|
|
||||||
newUser := oldUser.Copy()
|
|
||||||
newUser.Role = update.Role
|
|
||||||
newUser.Blocked = update.Blocked
|
|
||||||
// these two fields can't be set via API, only via direct call to the method
|
|
||||||
newUser.Issued = update.Issued
|
|
||||||
newUser.IntegrationReference = update.IntegrationReference
|
|
||||||
|
|
||||||
for _, newGroupID := range update.AutoGroups {
|
|
||||||
if _, ok := account.Groups[newGroupID]; !ok {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
|
||||||
newGroupID, update.Id)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
newUser.AutoGroups = update.AutoGroups
|
|
||||||
|
|
||||||
account.Users[newUser.Id] = newUser
|
if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
|
||||||
|
|
||||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
|
||||||
// expire peers that belong to the user who's getting blocked
|
|
||||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil {
|
// only auto groups, revoked status, and integration reference can be updated for now
|
||||||
|
newUser := oldUser.Copy()
|
||||||
|
newUser.Role = update.Role
|
||||||
|
newUser.Blocked = update.Blocked
|
||||||
|
newUser.AutoGroups = update.AutoGroups
|
||||||
|
// these two fields can't be set via API, only via direct call to the method
|
||||||
|
newUser.Issued = update.Issued
|
||||||
|
newUser.IntegrationReference = update.IntegrationReference
|
||||||
|
|
||||||
|
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
|
||||||
|
account.Users[newUser.Id] = newUser
|
||||||
|
|
||||||
|
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||||
|
// expire peers that belong to the user who's getting blocked
|
||||||
|
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||||
|
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||||
|
// need force update all auto groups in any case they will not be duplicated
|
||||||
|
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||||
|
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||||
|
}
|
||||||
|
|
||||||
|
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||||
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
|
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
updatedUsers = append(updatedUsers, updatedUserInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(expiredPeers) > 0 {
|
||||||
|
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
|
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
account.Network.IncSerial()
|
||||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
|
||||||
// need force update all auto groups in any case they will not be duplicated
|
return nil, err
|
||||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
}
|
||||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if account.Settings.GroupsPropagationEnabled {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
} else {
|
}
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return nil, err
|
for _, storeEvent := range eventsToStore {
|
||||||
|
storeEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||||
|
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||||
|
if newUser.IsBlocked() {
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
switch {
|
||||||
if oldUser.IsBlocked() != update.IsBlocked() {
|
case transferredOwnerRole:
|
||||||
if update.IsBlocked() {
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil)
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
|
||||||
|
})
|
||||||
|
case oldUser.Role != newUser.Role:
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.AutoGroups != nil {
|
||||||
|
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||||
|
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||||
|
for _, g := range removedGroups {
|
||||||
|
group := account.GetGroup(g)
|
||||||
|
if group != nil {
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||||
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
|
})
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil)
|
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, g := range addedGroups {
|
||||||
switch {
|
group := account.GetGroup(g)
|
||||||
case transferedOwnerRole:
|
if group != nil {
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil)
|
eventsToStore = append(eventsToStore, func() {
|
||||||
case oldUser.Role != newUser.Role:
|
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.AutoGroups != nil {
|
|
||||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
|
||||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
|
||||||
for _, g := range removedGroups {
|
|
||||||
group := account.GetGroup(g)
|
|
||||||
if group != nil {
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||||
} else {
|
})
|
||||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range addedGroups {
|
|
||||||
group := account.GetGroup(g)
|
|
||||||
if group != nil {
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
|
||||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
return eventsToStore
|
||||||
userData, err := am.lookupUserInCache(ctx, newUser.Id, account)
|
}
|
||||||
|
|
||||||
|
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
|
||||||
|
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
|
||||||
|
newInitiatorUser := initiatorUser.Copy()
|
||||||
|
newInitiatorUser.Role = UserRoleAdmin
|
||||||
|
account.Users[initiatorUser.Id] = newInitiatorUser
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserInfo retrieves the UserInfo for a given User and Account.
|
||||||
|
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||||
|
// it will attempt to look up the UserData from the cache.
|
||||||
|
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
|
||||||
|
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||||
|
userData, err := am.lookupUserInCache(ctx, user.Id, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newUser.ToUserInfo(userData, account.Settings)
|
return user.ToUserInfo(userData, account.Settings)
|
||||||
}
|
}
|
||||||
return newUser.ToUserInfo(nil, account.Settings)
|
return user.ToUserInfo(nil, account.Settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateUserUpdate validates the update operation for a user.
|
||||||
|
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
|
||||||
|
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||||
|
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||||
|
}
|
||||||
|
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||||
|
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||||
|
}
|
||||||
|
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||||
|
}
|
||||||
|
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||||
|
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||||
|
}
|
||||||
|
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||||
|
}
|
||||||
|
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
|
||||||
|
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newGroupID := range update.AutoGroups {
|
||||||
|
if _, ok := account.Groups[newGroupID]; !ok {
|
||||||
|
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||||
|
newGroupID, update.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||||
|
Loading…
x
Reference in New Issue
Block a user