[management] Remove redundant get account calls in GetAccountFromToken (#2615)

* refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor jwt groups extractor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to get account when necessary

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* revert handles change

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove GetUserByID from account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims to return account id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByName from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByID from store and refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user permissions and retrieves PAT

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor route, setupkey, nameserver and dns to get record(s) from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix add missing policy source posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add store lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga 2024-09-27 17:10:50 +03:00 committed by GitHub
parent 4ebf6e1c4c
commit acb73bd64a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 1279 additions and 981 deletions

View File

@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
) )
const ( const (
@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
GetAccount(ctx context.Context, accountID string) (*Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -107,7 +109,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@ -145,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
@ -268,6 +271,11 @@ type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"` Network *Network `gorm:"embedded;embeddedPrefix:network_"`
} }
// AccountDNSSettings used in gorm to only load dns settings and not whole account
type AccountDNSSettings struct {
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
}
type UserPermissions struct { type UserPermissions struct {
DashboardView string `json:"dashboard_view"` DashboardView string `json:"dashboard_view"`
} }
@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil return nil
} }
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// userID doesn't have an account associated with it, one account is created // If an accountID is provided, it checks if the account exists and returns it.
// domain is used to create a new account if no account is found // If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { // If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" { if accountID != "" {
return am.Store.GetAccount(ctx, accountID) exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) return "", err
} }
err = am.addAccountIDToIDPAppMeta(ctx, userID, account) if !exists {
if err != nil { return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
return nil, err
} }
return account, nil return accountID, nil
} }
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
}
return account.Id, nil
}
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
} }
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager // only possible with the enabled IdP manager
if am.idpManager == nil { if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil return nil
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account) user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil { if err != nil {
return err return err
@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return am.Store.SaveAccount(ctx, account) return am.Store.SaveAccount(ctx, account)
} }
// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountFromPAT returns Account and User associated with a personal access token // GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength { if len(token) != PATLength {
@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil return account, user, pat, nil
} }
// GetAccountFromToken returns an account associated with this token // GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" { if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty") return "", "", fmt.Errorf("user ID is empty")
} }
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
} }
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil { if err != nil {
return nil, nil, err return "", "", err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
unlock()
}
}()
account, err := am.Store.GetAccount(ctx, newAcc.Id)
if err != nil {
return nil, nil, err
} }
user := account.Users[claims.UserId] user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if user == nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, account, claims.UserId) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
return nil, nil, err return "", "", err
} }
} }
if account.Settings.JWTGroupsEnabled { if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
if account.Settings.JWTGroupsClaimName == "" { return "", "", err
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return account, user, nil
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok {
var groupsNames []string
for _, item := range slice {
if g, ok := item.(string); ok {
groupsNames = append(groupsNames, g)
} else {
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
}
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// if groups were added or modified, save the account
if account.SetJWTGroups(claims.UserId, groupsNames) {
if account.Settings.GroupsPropagationEnabled {
if user, err := account.FindUser(claims.UserId); err == nil {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
unlock()
alreadyUnlocked = true
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
}
}
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
}
} }
return account, user, nil return accountID, user.Id, nil
} }
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if settings == nil || !settings.JWTGroupsEnabled {
return nil
}
if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return nil
}
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
return nil
}
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
return nil
}
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate // if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain // if account is new, existing or if there is another account with the same domain
// //
@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" { if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty") return "", fmt.Errorf("user ID is empty")
} }
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil { if err != nil {
return nil, err return "", err
} }
if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
} }
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
return userAccountID, nil
} }
} }
@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil { if err != nil {
// if NotFound we are good to continue, otherwise return error // if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err) e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound { if !ok || e.Type() != status.NotFound {
return nil, err return "", err
} }
} }
account, err := am.Store.GetAccountByUser(ctx, claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil { if err == nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount() defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId) account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil { if err != nil {
return nil, err return "", err
} }
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as // we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account // was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost. // and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id primaryDomain := domainAccountID == "" || account.Id == domainAccountID
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) return "", err
if err != nil {
return nil, err
} }
return account, nil
return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil { var domainAccount *Account
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) if domainAccountID != "" {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount() defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil { if err != nil {
return nil, err return "", err
} }
} }
return am.handleNewUserAccount(ctx, domainAccount, claims)
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else { } else {
// other error // other error
return nil, err return "", err
} }
} }
@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions. // group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(ctx, claims) accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
// Ensures JWT group synchronization to the management is enabled before, // Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups. // filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled { if settings != nil && settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0) userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if !userHasAllowedGroup(allowedGroups, userJWTGroups) { if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups") return fmt.Errorf("user does not belong to any of the allowed JWT groups")
@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
return newLabel, nil return newLabel, nil
} }
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {
@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
return acc return acc
} }
// extractJWTGroups extracts the group names from a JWT token's claims.
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups. // userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups { for _, userGroup := range userGroups {

View File

@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID) assert.Equal(t, account.Id, ev.TargetID)
} }
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims type initUserParams jwtclaims.AuthorizationClaims
type test struct { type test struct {
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed") require.NoError(t, err, "update init user failed")
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount = acc initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
} }
t.Run("JWT groups disabled", func(t *testing.T) { t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, account.Groups, 1, "only ALL group should exists")
}) })
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
}) })
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 3, "groups should be added to the account") require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
return return
} }
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
} }
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) require.NoError(t, err, "unable to get account settings")
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
assert.NotNil(t, settings)
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
} }
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID") require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) require.NoError(t, err, "unable to get account settings")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second, PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })

View File

@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
} }
dnsSettings := account.DNSSettings.Copy()
return &dnsSettings, nil return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings

View File

@ -10,14 +10,15 @@ import (
"sync" "sync"
"time" "time"
"github.com/rs/xid" "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
return nil, err return nil, err
} }
return account.Users[userID].Copy(), nil user := account.Users[userID].Copy()
pat := make([]PersonalAccessToken, 0, len(user.PATs))
for _, token := range user.PATs {
if token != nil {
pat = append(pat, *token)
}
}
user.PATsG = pat
return user, nil
} }
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
return nil return nil
} }
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
} }
@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine return FileStoreEngine
} }
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented") return status.Errorf(status.Internal, "SaveUsers is not implemented")
} }
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented") return status.Errorf(status.Internal, "SaveGroups is not implemented")
} }
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return "", "", err
}
return account.Domain, account.DomainCategory, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
}
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
}
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
}
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
}
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
}
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
}
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
}
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
}
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
}

View File

@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
} }
// GetGroup object of the peers // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
return nil
}
// GetGroup returns a specific group by groupID in an account
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
group, ok := account.Groups[groupID]
if ok {
return group, nil
}
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) return am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
matchingGroups := make([]*nbgroup.Group, 0)
for _, group := range account.Groups {
if group.Name == groupName {
matchingGroups = append(matchingGroups, group)
}
}
if len(matchingGroups) == 0 {
return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
}
maxPeers := -1
var groupWithMostPeers *nbgroup.Group
for i, group := range matchingGroups {
if len(group.Peers) > maxPeers {
maxPeers = len(group.Peers)
groupWithMostPeers = matchingGroups[i]
}
}
return groupWithMostPeers, nil
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil return nil
} }
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil { if err = validateDeleteGroup(account, group, userId); err != nil {
return err return err
} }

View File

@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
} }
claims := s.jwtClaimsExtractor.FromToken(token) claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims) _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil { if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
} }

View File

@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
if !(user.HasAdminPower() || user.IsServiceUser) { settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) if err != nil {
util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(account) resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount) resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteAccount is a HTTP DELETE handler to delete an account // DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r) vars := mux.Vars(r)
targetAccountID := vars["accountId"] targetAccountID := vars["accountId"]
@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
func toAccountResponse(account *server.Account) *api.Account { func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil { if jwtAllowGroups == nil {
jwtAllowGroups = []string{} jwtAllowGroups = []string{}
} }
settings := api.AccountSettings{ apiSettings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsEnabled: &settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups, JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
} }
if account.Settings.Extra != nil { if settings.Extra != nil {
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
} }
return &api.Account{ return &api.Account{
Id: account.Id, Id: accountID,
Settings: settings, Settings: apiSettings,
} }
} }

View File

@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{ return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account, admin, nil return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil
}, },
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour

View File

@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account // GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account // UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups, DisabledManagementGroups: req.DisabledManagementGroups,
} }
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
} }
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}, },
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account // GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e) events[i] = toEventResponse(e)
} }
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{ return &EventsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
} }
return []*activity.Event{}, nil return []*activity.Event{}, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: claims.AccountId,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
}, user, nil
}, },
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil return make([]*server.UserInfo, 0), nil
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id) events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...) handler := initEventsTestData(accountID, events...)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -11,9 +11,9 @@ import (
"testing" "testing"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{ return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
user := server.NewAdminUser("test_user") return claims.AccountId, claims.UserId, nil
return &server.Account{ },
Id: claims.AccountId, GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
Users: map[string]*server.User{ return server.NewAdminUser(id), nil
"test_user": user,
},
}, user, nil
}, },
}, },
geolocationManager: geo, geolocationManager: geo,

View File

@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r) claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
return err
}
user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account // GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse := make([]*api.Group, 0, len(groups)) groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups { for _, group := range groups {
groupsResponse = append(groupsResponse, toGroupResponse(account, group)) groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
} }
util.WriteJSONObject(r.Context(), w, groupsResponse) util.WriteJSONObject(r.Context(), w, groupsResponse)
@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
// UpdateGroup handles update to a group identified by a given ID // UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
eg, ok := account.Groups[groupID] existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID { if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID, ID: groupID,
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
Issued: eg.Issued, Issued: existingGroup.Issued,
IntegrationReference: eg.IntegrationReference, IntegrationReference: existingGroup.IntegrationReference,
} }
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// CreateGroup handles group creation request // CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// DeleteGroup handles group deletion request // DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
allGroup, err := account.GetGroupAll() err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
if err != nil { if err != nil {
_, ok := err.(*server.GroupLinkError) _, ok := err.(*server.GroupLinkError)
if ok { if ok {
@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
// GetGroup returns a group // GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
switch r.Method { accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
case http.MethodGet: if err != nil {
groupID := mux.Vars(r)["groupId"] util.WriteError(r.Context(), err, w)
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
} }
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
cache := make(map[string]api.PeerMinimum) cache := make(map[string]api.PeerMinimum)
gr := api.Group{ gr := api.Group{
Id: group.ID, Id: group.ID,
@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
for _, pid := range group.Peers { for _, pid := range group.Peers {
_, ok := cache[pid] _, ok := cache[pid]
if !ok { if !ok {
peer, ok := account.Peers[pid] peer, ok := peersMap[pid]
if !ok { if !ok {
continue continue
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{ return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil return nil
}, },
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" { groups := map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
}
for _, group := range initGroups {
groups[group.ID] = group
}
group, ok := groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "not found") return nil, status.Errorf(status.NotFound, "not found")
} }
if groupID == "id-jwt-group" {
return &nbgroup.Group{ return group, nil
ID: "id-jwt-group",
Name: "Default Group",
Issued: nbgroup.GroupIssuedJWT,
}, nil
}
return &nbgroup.Group{
ID: "idofthegroup",
Name: "Group",
Issued: nbgroup.GroupIssuedAPI,
}, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: claims.AccountId, },
Domain: "hotmail.com", GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
Peers: TestPeers, if groupName == "All" {
Users: map[string]*server.User{ return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
user.Id: user, }
},
Groups: map[string]*nbgroup.Group{ return nil, fmt.Errorf("unknown group name")
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, },
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, return maps.Values(TestPeers), nil
},
}, user, nil
}, },
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" { if groupID == "linked-grp" {
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData(group)
p := initGroupTestData(adminUser, group)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData()
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData()
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account // GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
// CreateNameserverGroup handles nameserver group creation request // CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID // UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled, SearchDomainsEnabled: req.SearchDomainsEnabled,
} }
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
// DeleteNameserverGroup handles nameserver group deletion request // DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
// GetNameserverGroup handles a nameserver group Get request identified by ID // GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id" testNSGroupAccountID = "test_id"
) )
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{ var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: "super", Name: "super",
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
} }
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil return claims.AccountId, claims.UserId, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
// GetToken is HTTP GET handler that returns a personal access token for the given user // GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
// CreateToken is HTTP POST handler that creates a personal access token for the given user // CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil }, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return testAccount, testAccount.Users[existingUserID], nil return claims.AccountId, claims.UserId, nil
}, },
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID { if accountID != existingAccountID {
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: existingUserID, UserId: existingUserID,
Domain: testDomain, Domain: testDomain,
AccountId: testNSGroupAccountID, AccountId: existingAccountID,
} }
}), }),
), ),

View File

@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
// HandlePeer handles all peer requests for GET, PUT and DELETE operations // HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) h.deletePeer(r.Context(), accountID, userID, peerID, w)
return return
case http.MethodPut: case http.MethodGet, http.MethodPut:
h.updatePeer(r.Context(), account, user, peerID, w, r) account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
return if err != nil {
case http.MethodGet: util.WriteError(r.Context(), err, w)
h.getPeer(r.Context(), account, peerID, user.Id, w) return
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return return
default: default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
return
}
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers)) respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range peers { for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return return
} }
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := account.FindUser(userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// If the user is regular user and does not own the peer // If the user is regular user and does not own the peer
// with the given peerID return an empty list // with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser { if !user.HasAdminPower() && !user.IsServiceUser {

View File

@ -13,16 +13,15 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers { for _, peer := range peers {
peersMap[peer.ID] = peer.Copy() peersMap[peer.ID] = peer.Copy()
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
policy := &server.Policy{ policy := &server.Policy{
ID: "policy", ID: "policy",
AccountID: claims.AccountId, AccountID: accountID,
Name: "policy", Name: "policy",
Enabled: true, Enabled: true,
Rules: []*server.PolicyRule{ Rules: []*server.PolicyRule{
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
srvUser.IsServiceUser = true srvUser.IsServiceUser = true
account := &server.Account{ account := &server.Account{
Id: claims.AccountId, Id: accountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: peersMap, Peers: peersMap,
Users: map[string]*server.User{ Users: map[string]*server.User{
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
Groups: map[string]*nbgroup.Group{ Groups: map[string]*nbgroup.Group{
"group1": { "group1": {
ID: "group1", ID: "group1",
AccountID: claims.AccountId, AccountID: accountID,
Name: "group1", Name: "group1",
Issued: "api", Issued: "api",
Peers: maps.Keys(peersMap), Peers: maps.Keys(peersMap),
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}, },
} }
return account, account.Users[claims.UserId], nil return account, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{}) statuses := make(map[string]struct{})
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite // hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2) assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0] for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
} else { } else {
got = &api.Peer{} got = &api.Peer{}
err = json.Unmarshal(content, got) err = json.Unmarshal(content, got)

View File

@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account // GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
policies := []*api.Policy{} allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, policy := range accountPolicies { if err != nil {
resp := toPolicyResponse(account, policy) util.WriteError(r.Context(), err, w)
return
}
policies := make([]*api.Policy, 0, len(listPolicies))
for _, policy := range listPolicies {
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
// UpdatePolicy handles update to a policy identified by a given ID // UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
policyIdx := -1 _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, account, user, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
h.savePolicy(w, r, account, user, "") h.savePolicy(w, r, accountID, userID, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, accountID, userID, "")
} }
// savePolicy handles policy creation and update // savePolicy handles policy creation and update
func (h *Policies) savePolicy( func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
policyID string,
) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
return return
} }
isUpdate := policyID != ""
if policyID == "" { if policyID == "" {
policyID = xid.New().String() policyID = xid.New().String()
} }
@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
pr := server.PolicyRule{ pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name, Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations), Destinations: rule.Destinations,
Sources: groupMinimumsToStrings(account, rule.Sources), Sources: rule.Sources,
Bidirectional: rule.Bidirectional, Bidirectional: rule.Bidirectional,
} }
@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
} }
if req.SourcePostureChecks != nil { if req.SourcePostureChecks != nil {
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) policy.SourcePostureChecks = *req.SourcePostureChecks
} }
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toPolicyResponse(account, &policy) allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
// DeletePolicy handles policy deletion request // DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
// GetPolicy handles a group Get request identified by ID // GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
switch r.Method { vars := mux.Vars(r)
case http.MethodGet: policyID := vars["policyId"]
vars := mux.Vars(r) if len(policyID) == 0 {
policyID := vars["policyId"] util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
if len(policyID) == 0 { return
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
} }
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
} }
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
cache := make(map[string]api.GroupMinimum) cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{ ap := &api.Policy{
Id: &policy.ID, Id: &policy.ID,
@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
Protocol: api.PolicyRuleProtocol(r.Protocol), Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action), Action: api.PolicyRuleAction(r.Action),
} }
if len(r.Ports) != 0 { if len(r.Ports) != 0 {
portsCopy := r.Ports portsCopy := r.Ports
rule.Ports = &portsCopy rule.Ports = &portsCopy
} }
for _, gid := range r.Sources { for _, gid := range r.Sources {
_, ok := cache[gid] _, ok := cache[gid]
if ok { if ok {
continue continue
} }
if group, ok := account.Groups[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
cache[gid] = minimum cache[gid] = minimum
} }
} }
for _, gid := range r.Destinations { for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid] cachedMinimum, ok := cache[gid]
if ok { if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum) rule.Destinations = append(rule.Destinations, cachedMinimum)
continue continue
} }
if group, ok := account.Groups[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
} }
return ap return ap
} }
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
result := make([]string, 0, len(gm))
for _, g := range gm {
if _, ok := account.Groups[g]; !ok {
continue
}
result = append(result, g)
}
return result
}
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}

View File

@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
} }
return policy, nil return policy, nil
}, },
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") { if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set" policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set" policy.Rules[0].ID = "id-was-set"
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
user := server.NewAdminUser("test_user") return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
user := server.NewAdminUser(userID)
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: accountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Policies: []*server.Policy{ Policies: []*server.Policy{
{ID: "id-existed"}, {ID: "id-existed"},
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": user, "test_user": user,
}, },
}, user, nil }, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account // GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
postureChecks := []*api.PostureCheck{} postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
for _, postureCheck := range accountPostureChecks { for _, postureCheck := range listPostureChecks {
postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
} }
@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
// UpdatePostureCheck handles update to a posture check identified by a given ID // UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return return
} }
postureChecksIdx := -1 _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
for i, postureCheck := range account.PostureChecks { if err != nil {
if postureCheck.ID == postureChecksID { util.WriteError(r.Context(), err, w)
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return return
} }
p.savePostureChecks(w, r, account, user, postureChecksID) p.savePostureChecks(w, r, accountID, userID, postureChecksID)
} }
// CreatePostureCheck handles posture check creation request // CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
p.savePostureChecks(w, r, account, user, "") p.savePostureChecks(w, r, accountID, userID, "")
} }
// GetPostureCheck handles a posture check Get request identified by ID // GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return return
} }
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
// DeletePostureCheck handles posture check deletion request // DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return return
} }
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
} }
// savePostureChecks handles posture checks create and update // savePostureChecks handles posture checks create and update
func (p *PostureChecksHandler) savePostureChecks( func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
postureChecksID string,
) {
var ( var (
err error err error
req api.PostureCheckUpdate req api.PostureCheckUpdate
@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
return return
} }
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
} }
return accountPostureChecks, nil return accountPostureChecks, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
user := server.NewAdminUser("test_user") return claims.AccountId, claims.UserId, nil
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
PostureChecks: postureChecks,
}, user, nil
}, },
}, },
geolocationManager: &geolocation.Geolocation{}, geolocationManager: &geolocation.Geolocation{},

View File

@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account // GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
// CreateRoute handles route creation request // CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups peerGroupIds = *req.PeerGroups
} }
// Do not allow non-Linux peers newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
if peer := account.GetPeer(peerId); peer != nil { req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
if peer.Meta.GoOS != "linux" { )
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID // UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer peerID = *req.Peer
} }
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
newRoute := &route.Route{ newRoute := &route.Route{
ID: route.ID(routeID), ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId), NetID: route.NetID(req.NetworkId),
@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups newRoute.PeerGroups = *req.PeerGroups
} }
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// DeleteRoute handles route deletion request // DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
// GetRoute handles a route Get request identified by ID // GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return return

View File

@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
} }
if peerID != "" {
if peerID == nonLinuxExistingPeerID {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
return &route.Route{ return &route.Route{
ID: existingRouteID, ID: existingRouteID,
NetID: netID, NetID: netID,
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
if r.Peer == notFoundPeerID { if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
} }
if r.Peer == nonLinuxExistingPeerID {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
return nil return nil
}, },
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingAccount, testingAccount.Users["test_user"], nil //return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey // CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil { if req.Ephemeral != nil {
ephemeral = *req.Ephemeral ephemeral = *req.Ephemeral
} }
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral) req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
// GetSetupKey is a GET request to get a SetupKey by ID // GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
return return
} }
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKey is a PUT request to update server.SetupKey // UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name newKey.Name = req.Name
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
// GetAllSetupKeys is a GET request that returns a list of SetupKey // GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler { ) *SetupKeysHandler {
return &SetupKeysHandler{ return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*nbgroup.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
}, },
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool, _ int, _ string, ephemeral bool,

View File

@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
existingUser, ok := account.Users[userID] existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
if !ok { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) util.WriteError(r.Context(), err, w)
return return
} }
@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: userID, Id: targetUserID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked, Blocked: req.IsBlocked,
@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name name = *req.Name
} }
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email, Email: email,
Name: name, Name: name,
Role: req.Role, Role: req.Role,
@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler { func initUsersTestData() *UsersHandler {
return &UsersHandler{ return &UsersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return usersTestAccount.Users[id], nil
}, },
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)

View File

@ -23,10 +23,11 @@ import (
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -48,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@ -79,7 +80,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string GetDNSDomainFunc func() string
@ -105,6 +106,9 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface // GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID( func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
ctx context.Context, userId, accountId, domain string, if am.GetAccountIDByUserOrAccountIdFunc != nil {
) (*server.Account, error) { return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
if am.GetAccountByUserOrAccountIdFunc != nil {
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
} }
return nil, status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountID is not implemented", "method GetAccountIDByUserOrAccountID is not implemented",
) )
} }
@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
} }
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface // SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
if am.SavePolicyFunc != nil { if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy) return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
} }
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
} }
@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
} }
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface // GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
error, if am.GetAccountIDFromTokenFunc != nil {
) { return am.GetAccountIDFromTokenFunc(ctx, claims)
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(ctx, claims)
} }
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
} }
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
} }
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
} }
// GetAccountByID mocks GetAccountByID of the AccountManager interface
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
if am.GetAccountByIDFunc != nil {
return am.GetAccountByIDFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
}
// GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
if am.GetUserByIDFunc != nil {
return am.GetUserByIDFunc(ctx, id)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
}
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
if am.GetAccountSettingsFunc != nil {
return am.GetAccountSettingsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
}
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
if am.GetAccountFunc != nil {
return am.GetAccountFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
}

View File

@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
return nil, err
} }
if !(user.HasAdminPower() || user.IsServiceUser) { return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
}
nsGroup, found := account.NameServerGroups[nsGroupID]
if found {
return nsGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
} }
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}
return nsGroups, nil
} }
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {

View File

@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
} }
policy.Enabled = false policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return

View File

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices"
"strconv" "strconv"
"strings" "strings"
@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store // GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
} }
for _, policy := range account.Policies { return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
if policy.ID == policyID {
return policy, nil
}
}
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
exists := am.savePolicy(account, policy) if err = am.savePolicy(account, policy, isUpdate); err != nil {
return err
}
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
} }
action := activity.PolicyAdded action := activity.PolicyAdded
if exists { if isUpdate {
action = activity.PolicyUpdated action = activity.PolicyUpdated
} }
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store // ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
return nil, err
} }
if !(user.HasAdminPower() || user.IsServiceUser) { return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
}
return account.Policies, nil
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
return policy, nil return policy, nil
} }
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) { // savePolicy saves or updates a policy in the given account.
for i, p := range account.Policies { // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
if p.ID == policy.ID { func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
account.Policies[i] = policy for index, rule := range policyToSave.Rules {
exists = true rule.Sources = filterValidGroupIDs(account, rule.Sources)
break rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
} }
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
} }
if !exists {
account.Policies = append(account.Policies, policy) // Add the new policy to the account
} account.Policies = append(account.Policies, policyToSave)
return
return nil
} }
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
} }
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
}
}
return result
}

View File

@ -15,30 +15,16 @@ const (
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.HasAdminPower() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
} }
for _, postureChecks := range account.PostureChecks { return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
if postureChecks.ID == postureChecksID {
return postureChecks, nil
}
}
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
} }
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
} }
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.HasAdminPower() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
} }
return account.PostureChecks, nil return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {

View File

@ -17,29 +17,16 @@ import (
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
wantedRoute, found := account.Routes[routeID] return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
if found {
return wantedRoute, nil
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err return nil, err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
routes := make([]*route.Route, 0, len(account.Routes)) return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
for _, item := range account.Routes {
routes = append(routes, item)
}
return routes, nil
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {

View File

@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)

View File

@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser { keys := make([]*SetupKey, 0, len(setupKeys))
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") for _, key := range setupKeys {
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
var k *SetupKey var k *SetupKey
if !(user.HasAdminPower() || user.IsServiceUser) { if !user.IsAdminOrServiceUser() {
k = key.HiddenCopy(999) k = key.HiddenCopy(999)
} else { } else {
k = key.Copy() k = key.Copy()
@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
foundKey = key.Copy()
break
}
}
if foundKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
if foundKey.UpdatedAt.IsZero() { if setupKey.UpdatedAt.IsZero() {
foundKey.UpdatedAt = foundKey.CreatedAt setupKey.UpdatedAt = setupKey.CreatedAt
} }
if !(user.HasAdminPower() || user.IsServiceUser) { if !user.IsAdminOrServiceUser() {
foundKey = foundKey.HiddenCopy(999) setupKey = setupKey.HiddenCopy(999)
} }
return foundKey, nil return setupKey, nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {

View File

@ -36,6 +36,7 @@ const (
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?" keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found" peerNotFoundFMT = "peer %s not found"
) )
@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
var account Account accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", return nil, err
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
// TODO: rework to not call GetAccount // TODO: rework to not call GetAccount
return s.GetAccount(ctx, account.Id) return s.GetAccount(ctx, accountID)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
} }
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID) Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID) result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
} }
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
func (s *SqlStore) GetDB() *gorm.DB { func (s *SqlStore) GetDB() *gorm.DB {
return s.db return s.db
} }
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil
}
return false, result.Error
}
return accountID != "", nil
}
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
return account.Domain, account.DomainCategory, nil
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
}
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
}
return &group, nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@ -39,53 +40,81 @@ const (
type Store interface { type Store interface {
GetAllAccounts(ctx context.Context) []*Account GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
DeleteAccount(ctx context.Context, account *Account) error AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error SaveInstallationID(ctx context.Context, ID string) error
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func() AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func() AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data. // Close should close the store persisting all unsaved data.
Close(ctx context.Context) error Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation. // GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
} }

View File

@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
} }
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object. // ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups autoGroups := u.AutoGroups
@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings) return newUser.ToUserInfo(idpUser, account.Settings)
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
}
// GetUser looks up a user by provided authorization claims. // GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before. // It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
account, _, err := am.GetAccountFromToken(ctx, claims) accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err) return nil, fmt.Errorf("failed to get account with token claims %v", err)
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get an account from store %v", err) return nil, err
} }
user, ok := account.Users[claims.UserId] // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin) newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err) log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
} }
if newLogin { if newLogin {
meta := map[string]any{"timestamp": claims.LastLogin} meta := map[string]any{"timestamp": claims.LastLogin}
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
} }
return user, nil return user, nil
@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user // GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, err
} }
targetUser, ok := account.Users[targetUserID] targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if !ok { if err != nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, err
} }
executingUser, ok := account.Users[initiatorUserID] if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
if !ok { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
return nil, status.Errorf(status.NotFound, "user not found")
} }
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { for _, pat := range targetUser.PATsG {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") if pat.ID == tokenID {
return pat.Copy(), nil
}
} }
pat := targetUser.PATs[tokenID] return nil, status.Errorf(status.NotFound, "PAT not found")
if pat == nil {
return nil, status.Errorf(status.NotFound, "PAT not found")
}
return pat, nil
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, err
} }
targetUser, ok := account.Users[targetUserID] targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if !ok { if err != nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, err
} }
executingUser, ok := account.Users[initiatorUserID] if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
} }
var pats []*PersonalAccessToken pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
for _, pat := range targetUser.PATs { for _, pat := range targetUser.PATsG {
pats = append(pats, pat) pats = append(pats, pat.Copy())
} }
return pats, nil return pats, nil

View File

@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) {
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ account.Users[mockUserID] = &User{
Id: mockUserID, Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ PATs: map[string]*PersonalAccessToken{
mockTokenID1: { mockTokenID1: {
ID: mockTokenID1, ID: mockTokenID1,
@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) {
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ account.Users[mockUserID] = &User{
Id: mockUserID, Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ PATs: map[string]*PersonalAccessToken{
mockTokenID1: { mockTokenID1: {
ID: mockTokenID1, ID: mockTokenID1,
@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {