diff --git a/management/server/account.go b/management/server/account.go index 208315643..710b6f62f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,11 +20,6 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -41,6 +36,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration { type AccountManager interface { GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) @@ -75,12 +75,14 @@ type AccountManager interface { SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) - GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error + GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -107,7 +109,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -145,6 +147,7 @@ type AccountManager interface { SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) } type DefaultAccountManager struct { @@ -268,6 +271,11 @@ type AccountNetwork struct { Network *Network `gorm:"embedded;embeddedPrefix:network_"` } +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } @@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and -// userID doesn't have an account associated with it, one account is created -// domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { +// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. +// If an accountID is provided, it checks if the account exists and returns it. +// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// If the user doesn't have an account, it creates one using the provided domain. +// Returns the account ID or an error if none is found or created. +func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - return am.Store.GetAccount(ctx, accountID) - } else if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) + return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, userID, account) - if err != nil { - return nil, err + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) } - return account, nil + return accountID, nil } - return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") + if userID != "" { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } + + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + + return account.Id, nil + } + + return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") } func isNil(i idp.Manager) bool { @@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return am.Store.SaveAccount(ctx, account) } +// GetAccount returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { + return am.Store.GetAccount(ctx, accountID) +} + // GetAccountFromPAT returns Account and User associated with a personal access token func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { @@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st return account, user, pat, nil } -// GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +// GetAccountByID returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccount(ctx, accountID) +} + +// GetAccountIDFromToken returns an account ID associated with this token. +func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return nil, nil, fmt.Errorf("user ID is empty") + return "", "", fmt.Errorf("user ID is empty") } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { - return nil, nil, err - } - unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) - alreadyUnlocked := false - defer func() { - if !alreadyUnlocked { - unlock() - } - }() - - account, err := am.Store.GetAccount(ctx, newAcc.Id) - if err != nil { - return nil, nil, err + return "", "", err } - user := account.Users[claims.UserId] - if user == nil { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { // this is not really possible because we got an account by user ID - return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, account, claims.UserId) + err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { - return nil, nil, err + return "", "", err } } - if account.Settings.JWTGroupsEnabled { - if account.Settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") - return account, user, nil - } - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if slice, ok := claim.([]interface{}); ok { - var groupsNames []string - for _, item := range slice { - if g, ok := item.(string); ok { - groupsNames = append(groupsNames, g) - } else { - log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) - } - } - - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - // if groups were added or modified, save the account - if account.SetJWTGroups(claims.UserId, groupsNames) { - if account.Settings.GroupsPropagationEnabled { - if user, err := account.FindUser(claims.UserId); err == nil { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } else { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - unlock() - alreadyUnlocked = true - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - } - } - } else { - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } - } - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) - } + if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + return "", "", err } - return account, user, nil + return accountID, user.Id, nil } -// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, +// and propagates changes to peers if group propagation is enabled. +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if settings == nil || !settings.JWTGroupsEnabled { + return nil + } + + if settings.JWTGroupsClaimName == "" { + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + return nil + } + + // TODO: Remove GetAccount after refactoring account peer's update + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) + + // Update the account if group membership changes + if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + + if settings.GroupsPropagationEnabled { + account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + account.Network.IncSerial() + } + + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) + return nil + } + + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + for _, g := range addNewGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + + for _, g := range removeOldGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + } + + return nil +} + +// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // @@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } + // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { - return nil, err + return "", err } - if _, ok := accountFromID.Users[claims.UserId]; !ok { - return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID, nil + + domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + + if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { + return userAccountID, nil } } @@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return nil, err + return "", err } } - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) defer unlockAccount() - account, err = am.Store.GetAccountByUser(ctx, claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return nil, err + return "", err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. - primaryDomain := domainAccount == nil || account.Id == domainAccount.Id - - err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) - if err != nil { - return nil, err + primaryDomain := domainAccountID == "" || account.Id == domainAccountID + if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { + return "", err } - return account, nil + + return account.Id, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - if domainAccount != nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) + var domainAccount *Account + if domainAccountID != "" { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { - return nil, err + return "", err } } - return am.handleNewUserAccount(ctx, domainAccount, claims) + + account, err := am.handleNewUserAccount(ctx, domainAccount, claims) + if err != nil { + return "", err + } + return account.Id, nil } else { // other error - return nil, err + return "", err } } @@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, _, err := am.GetAccountIDFromToken(ctx, claims) + if err != nil { + return err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } // Ensures JWT group synchronization to the management is enabled before, // filtering access based on the allowed groups. - if account.Settings != nil && account.Settings.JWTGroupsEnabled { - if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } - } - } - } + if settings != nil && settings.JWTGroupsEnabled { + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) if !userHasAllowedGroup(allowedGroups, userJWTGroups) { return fmt.Errorf("user does not belong to any of the allowed JWT groups") @@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac return acc } +// extractJWTGroups extracts the group names from a JWT token's claims. +func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { + userJWTGroups := make([]string, 0) + + if claim, ok := claims.Raw[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } else { + log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) + } + + return userJWTGroups +} + // userHasAllowedGroup checks if a user belongs to any of the allowed groups. func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { for _, userGroup := range userGroups { diff --git a/management/server/account_test.go b/management/server/account_test.go index 03b5fa83e..303261bea 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { assert.Equal(t, account.Id, ev.TargetID) } -func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { +func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") + if testCase.inputUpdateAttrs { err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") @@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount = acc + initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount @@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") }) @@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} @@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) return } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - assert.NotNil(t, account.Settings) - assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) - assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") + + assert.NotNil(t, settings) + assert.Equal(t, settings.PeerLoginExpirationEnabled, true) + assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour) } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + // when we mark peer as connected, the peer login expiration routine should trigger err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") require.NoError(t, err, "unable to get account by ID") - assert.False(t, account.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + assert.False(t, settings.PeerLoginExpirationEnabled) + assert.Equal(t, settings.PeerLoginExpiration, time.Hour) + + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) diff --git a/management/server/dns.go b/management/server/dns.go index 1d156c90a..7410aaa15 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") } - dnsSettings := account.DNSSettings.Copy() - return &dnsSettings, nil + + return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings diff --git a/management/server/file_store.go b/management/server/file_store.go index 95d5b4e6e..994a4b1ee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,14 +10,15 @@ import ( "sync" "time" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - + "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/util" ) @@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID return nil, err } - return account.Users[userID].Copy(), nil + user := account.Users[userID].Copy() + pat := make([]PersonalAccessToken, 0, len(user.PATs)) + for _, token := range user.PATs { + if token != nil { + pat = append(pat, *token) + } + } + user.PATsG = pat + + return user, nil } -func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { account, err := s.getAccount(accountID) if err != nil { return nil, err @@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin return nil } -func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { +func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") } @@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } -func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { +func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } + +func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { + return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") +} + +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return "", "", err + } + + return account.Domain, account.DomainCategory, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { + _, exists := s.Accounts[id] + return exists, nil +} + +func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { + return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") +} + +func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") +} + +func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") +} + +func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") +} + +func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") + +} + +func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") +} + +func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") +} + +func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") +} + +func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") +} + +func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") +} + +func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") +} + +func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") +} + +func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") +} diff --git a/management/server/group.go b/management/server/group.go index 49720f347..aa387c058 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -// GetGroup object of the peers +// CheckGroupPermissions validates if a user has the necessary permissions to view groups +func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + return nil +} + +// GetGroup returns a specific group by groupID in an account func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } - - group, ok := account.Groups[groupID] - if ok { - return group, nil - } - - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil + return am.Store.GetAccountGroups(ctx, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - matchingGroups := make([]*nbgroup.Group, 0) - for _, group := range account.Groups { - if group.Name == groupName { - matchingGroups = append(matchingGroups, group) - } - } - - if len(matchingGroups) == 0 { - return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName) - } - - maxPeers := -1 - var groupWithMostPeers *nbgroup.Group - for i, group := range matchingGroups { - if len(group.Peers) > maxPeers { - maxPeers = len(group.Peers) - groupWithMostPeers = matchingGroups[i] - } - } - - return groupWithMostPeers, nil + return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) } // SaveGroup object of the peers @@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return nil } + allGroup, err := account.GetGroupAll() + if err != nil { + return err + } + + if allGroup.ID == groupID { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if err = validateDeleteGroup(account, group, userId); err != nil { return err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 5d7094b6a..cda3bc748 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) + _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index ffa5b9a28..91caa1512 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(account) + resp := toAccountResponse(accountID, settings) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount) + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, emptyObject{}) } -func toAccountResponse(account *server.Account) *api.Account { - jwtAllowGroups := account.Settings.JWTAllowGroups +func toAccountResponse(accountID string, settings *server.Settings) *api.Account { + jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} } - settings := api.AccountSettings{ - PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, - JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + apiSettings := api.AccountSettings{ + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } - if account.Settings.Extra != nil { - settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} + if settings.Extra != nil { + apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled} } return &api.Account{ - Id: account.Id, - Settings: settings, + Id: accountID, + Settings: apiSettings, } } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 45c7679e5..cacb3d430 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -23,8 +23,11 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return account, admin, nil + GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return account.Id, admin.Id, nil + }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + return account.Settings, nil }, UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55..13c2101a7 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index 897ae63dc..8baea7b15 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 428b4c164..ee0c63f28 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 8bdd508bf..e525cf2ee 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { +func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { @@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil @@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) { accountID := "test_account" adminUser := server.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) - handler := initEventsTestData(accountID, adminUser, events...) + handler := initEventsTestData(accountID, events...) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 7f4d6dc7c..19c916dd2 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return server.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index af4d3116f..418228abf 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + return err + } + + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873a..f369d1a00 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gorilla/mux" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" @@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse := make([]*api.Group, 0, len(groups)) for _, group := range groups { - groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group)) } util.WriteJSONObject(r.Context(), w, groupsResponse) @@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - eg, ok := account.Groups[groupID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) - return - } - - allGroup, err := account.GetGroupAll() + existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } + + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + if allGroup.ID == groupID { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return @@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { ID: groupID, Name: req.Name, Peers: peers, - Issued: eg.Issued, - IntegrationReference: eg.IntegrationReference, + Issued: existingGroup.Issued, + IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { @@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := account.GetGroupAll() - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if allGroup.ID == groupID { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) - return - } - - err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { @@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupID := mux.Vars(r)["groupId"] + if len(groupID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) + return + } + + group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - groupID := mux.Vars(r)["groupId"] - if len(groupID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) - return - } - - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) + } -func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, @@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { for _, pid := range group.Peers { _, ok := cache[pid] if !ok { - peer, ok := account.Peers[pid] + peer, ok := peersMap[pid] if !ok { continue } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9e..7f3c81f18 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/magiconair/properties/assert" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { +func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { @@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return nil }, GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - if groupID != "idofthegroup" { + groups := map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + } + + for _, group := range initGroups { + groups[group.ID] = group + } + + group, ok := groups[groupID] + if !ok { return nil, status.Errorf(status.NotFound, "not found") } - if groupID == "id-jwt-group" { - return &nbgroup.Group{ - ID: "id-jwt-group", - Name: "Default Group", - Issued: nbgroup.GroupIssuedJWT, - }, nil - } - return &nbgroup.Group{ - ID: "idofthegroup", - Name: "Group", - Issued: nbgroup.GroupIssuedAPI, - }, nil + + return group, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Peers: TestPeers, - Users: map[string]*server.User{ - user.Id: user, - }, - Groups: map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + if groupName == "All" { + return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + } + + return nil, fmt.Errorf("unknown group name") + }, + GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { @@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser, group) + p := initGroupTestData(group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d..e7a2bc2ae 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571..98c2e402d 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -29,14 +28,6 @@ const ( testNSGroupAccountID = "test_id" ) -var testingNSAccount = &server.Account{ - Id: testNSGroupAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), - }, -} - var baseExistingNSGroup = &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: "super", @@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3d..dfa9563e3 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] + targetUserID := vars["userId"] if len(userID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468..c28228a50 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -77,8 +77,8 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testAccount, testAccount.Users[existingUserID], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { @@ -119,7 +119,7 @@ func initPATTestData() *PATHandler { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, Domain: testDomain, - AccountId: testNSGroupAccountID, + AccountId: existingAccountID, } }), ), diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 5a2190d83..4fbbc3106 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) if err != nil { util.WriteError(ctx, err, w) return @@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodDelete: - h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodPut: - h.updatePeer(r.Context(), account, user, peerID, w, r) - return - case http.MethodGet: - h.getPeer(r.Context(), account, peerID, user.Id, w) + case http.MethodGet, http.MethodPut: + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + if r.Method == http.MethodGet { + h.getPeer(r.Context(), account, peerID, userID, w) + } else { + h.updatePeer(r.Context(), account, userID, peerID, w, r) + } return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { + respBody := make([]*api.PeerBatch, 0, len(account.Peers)) + for _, peer := range account.Peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) @@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request return } + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + user, err := account.FindUser(userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + // If the user is regular user and does not own the peer // with the given peerID return an empty list if !user.HasAdminPower() && !user.IsServiceUser { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dae264fff..f933eee14 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,16 +13,15 @@ import ( "time" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() @@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { policy := &server.Policy{ ID: "policy", - AccountID: claims.AccountId, + AccountID: accountID, Name: "policy", Enabled: true, Rules: []*server.PolicyRule{ @@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { srvUser.IsServiceUser = true account := &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Peers: peersMap, Users: map[string]*server.User{ @@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { Groups: map[string]*nbgroup.Group{ "group1": { ID: "group1", - AccountID: claims.AccountId, + AccountID: accountID, Name: "group1", Issued: "api", Peers: maps.Keys(peersMap), @@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - return account, account.Users[claims.UserId], nil + return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) @@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) { // hardcode this check for now as we only have two peers in this suite assert.Equal(t, len(respBody), 2) - assert.Equal(t, respBody[1].Connected, false) - got = respBody[0] + for _, peer := range respBody { + if peer.Id == testPeerID { + got = peer + } else { + assert.Equal(t, peer.Connected, false) + } + } + } else { got = &api.Peer{} err = json.Unmarshal(content, got) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f4..225d7e1f3 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/gorilla/mux" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/netbirdio/netbird/management/server" @@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - policies := []*api.Policy{} - for _, policy := range accountPolicies { - resp := toPolicyResponse(account, policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policies := make([]*api.Policy, 0, len(listPolicies)) + for _, policy := range listPolicies { + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) - return - } - - h.savePolicy(w, r, account, user, policyID) -} - -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, "") + h.savePolicy(w, r, accountID, userID, policyID) +} + +// CreatePolicy handles policy creation request +func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + h.savePolicy(w, r, accountID, userID, "") } // savePolicy handles policy creation and update -func (h *Policies) savePolicy( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - policyID string, -) { +func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -127,6 +122,8 @@ func (h *Policies) savePolicy( return } + isUpdate := policyID != "" + if policyID == "" { policyID = xid.New().String() } @@ -141,8 +138,8 @@ func (h *Policies) savePolicy( pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor Name: rule.Name, - Destinations: groupMinimumsToStrings(account, rule.Destinations), - Sources: groupMinimumsToStrings(account, rule.Sources), + Destinations: rule.Destinations, + Sources: rule.Sources, Bidirectional: rule.Bidirectional, } @@ -207,15 +204,21 @@ func (h *Policies) savePolicy( } if req.SourcePostureChecks != nil { - policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) + policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { util.WriteError(r.Context(), err, w) return } - resp := toPolicyResponse(account, &policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, &policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -227,12 +230,11 @@ func (h *Policies) savePolicy( // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id vars := mux.Vars(r) policyID := vars["policyId"] @@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) - return - } - - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) + vars := mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return } + + policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return + } + + util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { +func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + cache := make(map[string]api.GroupMinimum) ap := &api.Policy{ Id: &policy.ID, @@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic Protocol: api.PolicyRuleProtocol(r.Protocol), Action: api.PolicyRuleAction(r.Action), } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic cache[gid] = minimum } } + for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { rule.Destinations = append(rule.Destinations, cachedMinimum) continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic } return ap } - -func groupMinimumsToStrings(account *server.Account, gm []string) []string { - result := make([]string, 0, len(gm)) - for _, g := range gm { - if _, ok := account.Groups[g]; !ok { - continue - } - result = append(result, g) - } - return result -} - -func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) - for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } - } - - } - return result -} diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb07..228ebcbce 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + }, + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + user := server.NewAdminUser(userID) return &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Policies: []*server.Policy{ {ID: "id-existed"}, @@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Users: map[string]*server.User{ "test_user": user, }, - }, user, nil + }, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b80..1d020e9bc 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - postureChecks := []*api.PostureCheck{} - for _, postureCheck := range accountPostureChecks { + postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks)) + for _, postureCheck := range listPostureChecks { postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } @@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - postureChecksIdx := -1 - for i, postureCheck := range account.PostureChecks { - if postureCheck.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) + _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, postureChecksID) + p.savePostureChecks(w, r, accountID, userID, postureChecksID) } // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, "") + p.savePostureChecks(w, r, accountID, userID, "") } // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - postureChecksID string, -) { +func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate @@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks( return } - if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 974edafde..02f0f0d83 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - PostureChecks: postureChecks, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, geolocationManager: &geolocation.Geolocation{}, diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334..0932e6445 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } - // Do not allow non-Linux peers - if peer := account.GetPeer(peerId); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) - return - } - } - - newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute, + ) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } - // do not allow non Linux peers - if peer := account.GetPeer(peerID); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) - return - } - } - newRoute := &route.Route{ ID: route.ID(routeID), NetID: route.NetID(req.NetworkId), @@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d..2c367cac3 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler { if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) } + if peerID != "" { + if peerID == nonLinuxExistingPeerID { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + return &route.Route{ ID: existingRouteID, NetID: netID, @@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } + + if r.Peer == nonLinuxExistingPeerID { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + return nil }, DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { @@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + //return testingAccount, testingAccount.Users["test_user"], nil + return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfaba..8514f0b55 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id, ephemeral) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return @@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index bfa0ec008..2d15287af 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - SetupKeys: map[string]*server.SetupKey{ - defaultKey.Key: defaultKey, - }, - Groups: map[string]*nbgroup.Group{ - "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 2c2aed842..6e151a0da 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - existingUser, ok := account.Users[userID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) return } @@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ - Id: userID, + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, Blocked: req.IsBlocked, @@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a78ac3a4e..f3d989da1 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return usersTestAccount, usersTestAccount.Users[claims.UserId], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return usersTestAccount.Id, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return usersTestAccount.Users[id], nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 495325252..df12ec1c4 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -23,10 +23,11 @@ import ( type MockAccountManager struct { GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -48,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -79,7 +80,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string @@ -105,6 +106,9 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { @@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserOrAccountID( - ctx context.Context, userId, accountId, domain string, -) (*server.Account, error) { - if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { + if am.GetAccountIDByUserOrAccountIdFunc != nil { + return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetAccountByUserOrAccountID is not implemented", + "method GetAccountIDByUserOrAccountID is not implemented", ) } @@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } @@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, - error, -) { - if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(ctx, claims) +// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface +func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + if am.GetAccountIDFromTokenFunc != nil { + return am.GetAccountIDFromTokenFunc(ctx, claims) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") } func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { @@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") } + +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { + if am.GetAccountByIDFunc != nil { + return am.GetAccountByIDFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") +} + +// GetUserByID mocks GetUserByID of the AccountManager interface +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { + if am.GetUserByIDFunc != nil { + return am.GetUserByIDFunc(ctx, id) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") +} + +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + if am.GetAccountSettingsFunc != nil { + return am.GetAccountSettingsFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") +} + +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { + if am.GetAccountFunc != nil { + return am.GetAccountFunc(ctx, accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 636f7cfee..0eb5d9ae4 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups") - } - - nsGroup, found := account.NameServerGroups[nsGroupID] - if found { - return nsGroup.Copy(), nil - } - - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) - for _, item := range account.NameServerGroups { - nsGroups = append(nsGroups, item.Copy()) - } - - return nsGroups, nil + return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4b2ec66c6..d329e04bc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return diff --git a/management/server/policy.go b/management/server/policy.go index aaf9b6e72..5d07ba8f8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,6 +3,7 @@ package server import ( "context" _ "embed" + "slices" "strconv" "strings" @@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - for _, policy := range account.Policies { - if policy.ID == policyID { - return policy, nil - } - } - - return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - exists := am.savePolicy(account, policy) + if err = am.savePolicy(account, policy, isUpdate); err != nil { + return err + } account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { @@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } action := activity.PolicyAdded - if exists { + if isUpdate { action = activity.PolicyUpdated } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) @@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies") - } - - return account.Policies, nil + return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { @@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) { - for i, p := range account.Policies { - if p.ID == policy.ID { - account.Policies[i] = policy - exists = true - break +// savePolicy saves or updates a policy in the given account. +// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { + for index, rule := range policyToSave.Rules { + rule.Sources = filterValidGroupIDs(account, rule.Sources) + rule.Destinations = filterValidGroupIDs(account, rule.Destinations) + policyToSave.Rules[index] = rule + } + + if policyToSave.SourcePostureChecks != nil { + policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) + } + + if isUpdate { + policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) + if policyIdx < 0 { + return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + + // Update the existing policy + account.Policies[policyIdx] = policyToSave + return nil } - if !exists { - account.Policies = append(account.Policies, policy) - } - return + + // Add the new policy to the account + account.Policies = append(account.Policies, policyToSave) + + return nil } func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { @@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } return nil } + +// filterValidPostureChecks filters and returns the posture check IDs from the given list +// that are valid within the provided account. +func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { + result := make([]string, 0, len(postureChecksIds)) + for _, id := range postureChecksIds { + for _, postureCheck := range account.PostureChecks { + if id == postureCheck.ID { + result = append(result, id) + continue + } + } + } + return result +} + +// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. +func filterValidGroupIDs(account *Account, groupIDs []string) []string { + result := make([]string, 0, len(groupIDs)) + for _, groupID := range groupIDs { + if _, exists := account.Groups[groupID]; exists { + result = append(result, groupID) + } + } + return result +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 4180550e6..9a4b679ce 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -15,30 +15,16 @@ const ( ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - for _, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks, nil - } - } - - return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) } func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { @@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun } func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return account.PostureChecks, nil + return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { diff --git a/management/server/route.go b/management/server/route.go index 064f3c105..6c1c8b1b3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -17,29 +17,16 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - wantedRoute, found := account.Routes[routeID] - if found { - return wantedRoute, nil - } - - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) + return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. @@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } + // Do not allow non-Linux peers + if peer := account.GetPeer(peerID); peer != nil { + if peer.Meta.GoOS != "linux" { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + // Do not allow non-Linux peers + if peer := account.GetPeer(routeToSave.Peer); peer != nil { + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - routes := make([]*route.Route, 0, len(account.Routes)) - for _, item := range account.Routes { - routes = append(routes, item) - } - - return routes, nil + return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { diff --git a/management/server/route_test.go b/management/server/route_test.go index 506bfb0a8..4533c6b7e 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 859f1b0b9..9521e22d3 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - keys := make([]*SetupKey, 0, len(account.SetupKeys)) - for _, key := range account.SetupKeys { + keys := make([]*SetupKey, 0, len(setupKeys)) + for _, key := range setupKeys { var k *SetupKey - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() { k = key.HiddenCopy(999) } else { k = key.Copy() @@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - var foundKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyID { - foundKey = key.Copy() - break - } - } - if foundKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) - if foundKey.UpdatedAt.IsZero() { - foundKey.UpdatedAt = foundKey.CreatedAt + if setupKey.UpdatedAt.IsZero() { + setupKey.UpdatedAt = setupKey.CreatedAt } - if !(user.HasAdminPower() || user.IsServiceUser) { - foundKey = foundKey.HiddenCopy(999) + if !user.IsAdminOrServiceUser() { + setupKey = setupKey.HiddenCopy(999) } - return foundKey, nil + return setupKey, nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 8fa5f9d05..85c68ef44 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -36,6 +36,7 @@ const ( idQueryCondition = "id = ?" keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" + accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" ) @@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { - var account Account - - result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if err != nil { + return nil, err } // TODO: rework to not call GetAccount - return s.GetAccount(ctx, account.Id) + return s.GetAccount(ctx, accountID) +} + +func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + var accountID string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory, + ).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting account from store") + } + + return accountID, nil } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { @@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&user, idQueryCondition, userID) + Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Find(&groups, idQueryCondition, accountID) + result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { - var user User var accountID string - result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { func (s *SqlStore) GetDB() *gorm.DB { return s.db } + +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { + var accountDNSSettings AccountDNSSettings + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountDNSSettings, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "dns settings not found") + } + return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + } + return &accountDNSSettings.DNSSettings, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { + var accountID string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Select("id").First(&accountID, idQueryCondition, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return false, nil + } + return false, result.Error + } + + return accountID != "", nil +} + +// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. +func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { + var account Account + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + Where(idQueryCondition, accountID).First(&account) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", "", status.Errorf(status.NotFound, "account not found") + } + return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) + } + + return account.Domain, account.DomainCategory, nil +} + +// GetGroupByID retrieves a group by ID and account ID. +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { + return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +} + +// GetGroupByName retrieves a group by name and account ID. +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + } + return &group, nil +} + +// GetAccountPolicies retrieves policies for an account. +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) +} + +// GetPolicyByID retrieves a policy by its ID and account ID. +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { + return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) +} + +// GetAccountPostureChecks retrieves posture checks for an account. +func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { + return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetPostureChecksByID retrieves posture checks by their ID and account ID. +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { + return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) +} + +// GetAccountRoutes retrieves network routes for an account. +func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { + return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetRouteByID retrieves a route by its ID and account ID. +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { + return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) +} + +// GetAccountSetupKeys retrieves setup keys for an account. +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { + return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetSetupKeyByID retrieves a setup key by its ID and account ID. +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { + return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +} + +// GetAccountNameServerGroups retrieves name server groups for an account. +func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetNameServerGroupByID retrieves a name server group by its ID and account ID. +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { + return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) +} + +// getRecords retrieves records from the database based on the account ID. +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { + var record []T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) + } + + return record, nil +} + +// getRecordByID retrieves a record by its ID and account ID from the database. +func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { + var record T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "%s not found", recordType) + } + return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) + } + return &record, nil +} diff --git a/management/server/store.go b/management/server/store.go index 84b3b140c..f34a73c2d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -39,53 +40,81 @@ const ( type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) - DeleteAccount(ctx context.Context, account *Account) error + AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) + GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(peerKey string) (string, error) + GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) - GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) + GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) + SaveAccount(ctx context.Context, account *Account) error + DeleteAccount(ctx context.Context, account *Account) error + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) - GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - SaveAccount(ctx context.Context, account *Account) error SaveUsers(accountID string, users map[string]*User) error - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + + GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + + GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(accountID string, peer *nbpeer.Peer) error + + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + + GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + + GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) + GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + // Close should close the store persisting all unsaved data. Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error - IncrementNetworkSerial(ctx context.Context, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error } diff --git a/management/server/user.go b/management/server/user.go index 9e60bb94b..6d01561c6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool { return u.Role == UserRoleAdmin || u.Role == UserRoleOwner } +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups @@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { + return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +} + // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccount(ctx, account.Id) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, fmt.Errorf("failed to get an account from store %v", err) + return nil, err } - user, ok := account.Users[claims.UserId] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - - // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC + // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) } return user, nil @@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) + return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") + for _, pat := range targetUser.PATsG { + if pat.ID == tokenID { + return pat.Copy(), nil + } } - pat := targetUser.PATs[tokenID] - if pat == nil { - return nil, status.Errorf(status.NotFound, "PAT not found") - } - - return pat, nil + return nil, status.Errorf(status.NotFound, "PAT not found") } // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) + return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - var pats []*PersonalAccessToken - for _, pat := range targetUser.PATs { - pats = append(pats, pat) + pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + for _, pat := range targetUser.PATsG { + pats = append(pats, pat.Copy()) } return pats, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index 272060276..e394ef840 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") + assert.NoError(t, err) + + acc, err := am.Store.GetAccount(context.Background(), accID) assert.NoError(t, err) for _, id := range tc.expectedDeleted {