mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 08:03:30 +01:00
[management] Remove redundant get account calls in GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor jwt groups extractor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to get account when necessary Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * revert handles change Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove GetUserByID from account manager Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims to return account id Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to use GetAccountIDFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove locks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByName from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByID from store and refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor retrieval of policy and posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor user permissions and retrieves PAT Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor route, setupkey, nameserver and dns to get record(s) from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix add missing policy source posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add store lock Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add get account Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
4ebf6e1c4c
commit
acb73bd64a
@ -20,11 +20,6 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@ -41,6 +36,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
|
||||
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
@ -75,12 +75,14 @@ type AccountManager interface {
|
||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
@ -107,7 +109,7 @@ type AccountManager interface {
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@ -145,6 +147,7 @@ type AccountManager interface {
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@ -268,6 +271,11 @@ type AccountNetwork struct {
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
// AccountDNSSettings used in gorm to only load dns settings and not whole account
|
||||
type AccountDNSSettings struct {
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// userID doesn't have an account associated with it, one account is created
|
||||
// domain is used to create a new account if no account is found
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
|
||||
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
||||
// If an accountID is provided, it checks if the account exists and returns it.
|
||||
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||
if accountID != "" {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
} else if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
return "", err
|
||||
}
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
||||
}
|
||||
return account, nil
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
}
|
||||
|
||||
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
||||
}
|
||||
|
||||
func isNil(i idp.Manager) bool {
|
||||
@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
||||
return am.Store.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||
if len(token) != PATLength {
|
||||
@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
|
||||
return account, user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, nil, fmt.Errorf("user ID is empty")
|
||||
return "", "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@ -1739,70 +1783,84 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user := account.Users[claims.UserId]
|
||||
if user == nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings == nil || !settings.JWTGroupsEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if settings.JWTGroupsClaimName == "" {
|
||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||
return account, user, nil
|
||||
}
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if slice, ok := claim.([]interface{}); ok {
|
||||
var groupsNames []string
|
||||
for _, item := range slice {
|
||||
if g, ok := item.(string); ok {
|
||||
groupsNames = append(groupsNames, g)
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove GetAccount after refactoring account peer's update
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
copy(oldGroups, user.AutoGroups)
|
||||
// if groups were added or modified, save the account
|
||||
if account.SetJWTGroups(claims.UserId, groupsNames) {
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
||||
|
||||
// Update the account if group membership changes
|
||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||
account.Network.IncSerial()
|
||||
}
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
unlock()
|
||||
alreadyUnlocked = true
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
@ -1813,6 +1871,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
@ -1824,25 +1883,11 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||
return accountFromID, nil
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account, nil
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
||||
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
|
||||
return newLabel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
||||
return acc
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
|
@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
assert.Equal(t, account.Id, ev.TargetID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
|
||||
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
initAccount = acc
|
||||
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user and account IDs are empty")
|
||||
}
|
||||
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
assert.NotNil(t, account.Settings)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "unable to get account by ID")
|
||||
|
||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
|
@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
}
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
return &dnsSettings, nil
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
|
@ -10,14 +10,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Users[userID].Copy(), nil
|
||||
user := account.Users[userID].Copy()
|
||||
pat := make([]PersonalAccessToken, 0, len(user.PATs))
|
||||
for _, token := range user.PATs {
|
||||
if token != nil {
|
||||
pat = append(pat, *token)
|
||||
}
|
||||
}
|
||||
user.PATsG = pat
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
||||
}
|
||||
|
||||
@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||
return FileStoreEngine
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
|
||||
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
|
||||
_, exists := s.Accounts[id]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
||||
}
|
||||
|
@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
|
||||
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||
}
|
||||
|
||||
// GetGroup object of the peers
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGroup returns a specific group by groupID in an account
|
||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if ok {
|
||||
return group, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
return am.Store.GetAccountGroups(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matchingGroups := make([]*nbgroup.Group, 0)
|
||||
for _, group := range account.Groups {
|
||||
if group.Name == groupName {
|
||||
matchingGroups = append(matchingGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
if len(matchingGroups) == 0 {
|
||||
return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
|
||||
}
|
||||
|
||||
maxPeers := -1
|
||||
var groupWithMostPeers *nbgroup.Group
|
||||
for i, group := range matchingGroups {
|
||||
if len(group.Peers) > maxPeers {
|
||||
maxPeers = len(group.Peers)
|
||||
groupWithMostPeers = matchingGroups[i]
|
||||
}
|
||||
}
|
||||
|
||||
return groupWithMostPeers, nil
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
return nil
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
||||
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(account)
|
||||
resp := toAccountResponse(accountID, settings)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount)
|
||||
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(account *server.Account) *api.Account {
|
||||
jwtAllowGroups := account.Settings.JWTAllowGroups
|
||||
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
}
|
||||
|
||||
settings := api.AccountSettings{
|
||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
||||
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
|
||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
||||
apiSettings := api.AccountSettings{
|
||||
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
|
||||
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
|
||||
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
|
||||
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
|
||||
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
|
||||
JwtAllowGroups: &jwtAllowGroups,
|
||||
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
|
||||
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
|
||||
if settings.Extra != nil {
|
||||
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
Id: account.Id,
|
||||
Settings: settings,
|
||||
Id: accountID,
|
||||
Settings: apiSettings,
|
||||
}
|
||||
}
|
||||
|
@ -23,8 +23,11 @@ import (
|
||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||
return &AccountsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return account, admin, nil
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return account.Id, admin.Id, nil
|
||||
},
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
|
@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetDNSSettings returns the DNS settings for the account
|
||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
// UpdateDNSSettings handles update to DNS settings of an account
|
||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
||||
// GetAllEvents list of the given account
|
||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
events[i] = toEventResponse(e)
|
||||
}
|
||||
|
||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
||||
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
||||
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
|
||||
return &EventsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
}
|
||||
return []*activity.Event{}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
events := generateEvents(accountID, adminUser.Id)
|
||||
handler := initEventsTestData(accountID, adminUser, events...)
|
||||
handler := initEventsTestData(accountID, events...)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -11,9 +11,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
||||
|
||||
return &GeolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
}, user, nil
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return server.NewAdminUser(id), nil
|
||||
},
|
||||
},
|
||||
geolocationManager: geo,
|
||||
|
@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
|
||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
claims := l.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := l.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
||||
// GetAllGroups list for the account
|
||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
groupsResponse := make([]*api.Group, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
||||
groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||
@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
// UpdateGroup handles update to a group identified by a given ID
|
||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
eg, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
return
|
||||
@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
ID: groupID,
|
||||
Name: req.Name,
|
||||
Peers: peers,
|
||||
Issued: eg.Issued,
|
||||
IntegrationReference: eg.IntegrationReference,
|
||||
Issued: existingGroup.Issued,
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||
}
|
||||
|
||||
// CreateGroup handles group creation request
|
||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
||||
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||
}
|
||||
|
||||
// DeleteGroup handles group deletion request
|
||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
||||
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||
if err != nil {
|
||||
_, ok := err.(*server.GroupLinkError)
|
||||
if ok {
|
||||
@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
// GetGroup returns a group
|
||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||
|
||||
}
|
||||
|
||||
func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
|
||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
||||
cache := make(map[string]api.PeerMinimum)
|
||||
gr := api.Group{
|
||||
Id: group.ID,
|
||||
@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
||||
for _, pid := range group.Peers {
|
||||
_, ok := cache[pid]
|
||||
if !ok {
|
||||
peer, ok := account.Peers[pid]
|
||||
peer, ok := peersMap[pid]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/magiconair/properties/assert"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
||||
return &GroupsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
return nil
|
||||
},
|
||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
if groupID == "id-jwt-group" {
|
||||
return &nbgroup.Group{
|
||||
ID: "id-jwt-group",
|
||||
Name: "Default Group",
|
||||
Issued: nbgroup.GroupIssuedJWT,
|
||||
}, nil
|
||||
}
|
||||
return &nbgroup.Group{
|
||||
ID: "idofthegroup",
|
||||
Name: "Group",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: TestPeers,
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
groups := map[string]*nbgroup.Group{
|
||||
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
||||
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
||||
}
|
||||
|
||||
for _, group := range initGroups {
|
||||
groups[group.ID] = group
|
||||
}
|
||||
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
|
||||
return group, nil
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown group name")
|
||||
},
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return maps.Values(TestPeers), nil
|
||||
},
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
if groupID == "linked-grp" {
|
||||
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser, group)
|
||||
p := initGroupTestData(group)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
// CreateNameserverGroup handles nameserver group creation request
|
||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
// DeleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@ -29,14 +28,6 @@ const (
|
||||
testNSGroupAccountID = "test_id"
|
||||
)
|
||||
|
||||
var testingNSAccount = &server.Account{
|
||||
Id: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
},
|
||||
}
|
||||
|
||||
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: "super",
|
||||
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: testNSGroupAccountID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
|
@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
||||
case http.MethodGet, http.MethodPut:
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
||||
}
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||
} else {
|
||||
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||
}
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
|
@ -13,16 +13,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: claims.AccountId,
|
||||
AccountID: accountID,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: claims.AccountId,
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
}
|
||||
|
||||
return account, account.Users[claims.UserId], nil
|
||||
return account, nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
statuses := make(map[string]struct{})
|
||||
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
// hardcode this check for now as we only have two peers in this suite
|
||||
assert.Equal(t, len(respBody), 2)
|
||||
assert.Equal(t, respBody[1].Connected, false)
|
||||
|
||||
got = respBody[0]
|
||||
for _, peer := range respBody {
|
||||
if peer.Id == testPeerID {
|
||||
got = peer
|
||||
} else {
|
||||
assert.Equal(t, peer.Connected, false)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
got = &api.Peer{}
|
||||
err = json.Unmarshal(content, got)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllPolicies list for the account
|
||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policies := []*api.Policy{}
|
||||
for _, policy := range accountPolicies {
|
||||
resp := toPolicyResponse(account, policy)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policies := make([]*api.Policy, 0, len(listPolicies))
|
||||
for _, policy := range listPolicies {
|
||||
resp := toPolicyResponse(allGroups, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
// UpdatePolicy handles update to a policy identified by a given ID
|
||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
policyIdx := -1
|
||||
for i, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
policyIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, account, user, policyID)
|
||||
}
|
||||
|
||||
// CreatePolicy handles policy creation request
|
||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, account, user, "")
|
||||
h.savePolicy(w, r, accountID, userID, policyID)
|
||||
}
|
||||
|
||||
// CreatePolicy handles policy creation request
|
||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, "")
|
||||
}
|
||||
|
||||
// savePolicy handles policy creation and update
|
||||
func (h *Policies) savePolicy(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
account *server.Account,
|
||||
user *server.User,
|
||||
policyID string,
|
||||
) {
|
||||
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
|
||||
return
|
||||
}
|
||||
|
||||
isUpdate := policyID != ""
|
||||
|
||||
if policyID == "" {
|
||||
policyID = xid.New().String()
|
||||
}
|
||||
@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
Name: rule.Name,
|
||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
||||
Destinations: rule.Destinations,
|
||||
Sources: rule.Sources,
|
||||
Bidirectional: rule.Bidirectional,
|
||||
}
|
||||
|
||||
@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
|
||||
}
|
||||
|
||||
if req.SourcePostureChecks != nil {
|
||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||
}
|
||||
|
||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, &policy)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(allGroups, &policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
|
||||
// DeletePolicy handles policy deletion request
|
||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@ -252,14 +254,12 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
// GetPolicy handles a group Get request identified by ID
|
||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@ -267,25 +267,33 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, policy)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(allGroups, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
||||
}
|
||||
}
|
||||
|
||||
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
|
||||
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
cache := make(map[string]api.GroupMinimum)
|
||||
ap := &api.Policy{
|
||||
Id: &policy.ID,
|
||||
@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
||||
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
||||
Action: api.PolicyRuleAction(r.Action),
|
||||
}
|
||||
|
||||
if len(r.Ports) != 0 {
|
||||
portsCopy := r.Ports
|
||||
rule.Ports = &portsCopy
|
||||
}
|
||||
|
||||
for _, gid := range r.Sources {
|
||||
_, ok := cache[gid]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
if group, ok := groupsMap[gid]; ok {
|
||||
minimum := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
||||
cache[gid] = minimum
|
||||
}
|
||||
}
|
||||
|
||||
for _, gid := range r.Destinations {
|
||||
cachedMinimum, ok := cache[gid]
|
||||
if ok {
|
||||
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
||||
continue
|
||||
}
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
if group, ok := groupsMap[gid]; ok {
|
||||
minimum := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
||||
}
|
||||
return ap
|
||||
}
|
||||
|
||||
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
|
||||
result := make([]string, 0, len(gm))
|
||||
for _, g := range gm {
|
||||
if _, ok := account.Groups[g]; !ok {
|
||||
continue
|
||||
}
|
||||
result = append(result, g)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
|
||||
result := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if id == postureCheck.ID {
|
||||
result = append(result, id)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
user := server.NewAdminUser(userID)
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Policies: []*server.Policy{
|
||||
{ID: "id-existed"},
|
||||
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
},
|
||||
}, user, nil
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
||||
// GetAllPostureChecks list for the account
|
||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks := []*api.PostureCheck{}
|
||||
for _, postureCheck := range accountPostureChecks {
|
||||
postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
|
||||
for _, postureCheck := range listPostureChecks {
|
||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
postureChecksIdx := -1
|
||||
for i, postureCheck := range account.PostureChecks {
|
||||
if postureCheck.ID == postureChecksID {
|
||||
postureChecksIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, account, user, postureChecksID)
|
||||
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
||||
}
|
||||
|
||||
// CreatePostureCheck handles posture check creation request
|
||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, account, user, "")
|
||||
p.savePostureChecks(w, r, accountID, userID, "")
|
||||
}
|
||||
|
||||
// GetPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
// DeletePostureCheck handles posture check deletion request
|
||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
// savePostureChecks handles posture checks create and update
|
||||
func (p *PostureChecksHandler) savePostureChecks(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
account *server.Account,
|
||||
user *server.User,
|
||||
postureChecksID string,
|
||||
) {
|
||||
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
||||
var (
|
||||
err error
|
||||
req api.PostureCheckUpdate
|
||||
@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
},
|
||||
PostureChecks: postureChecks,
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
geolocationManager: &geolocation.Geolocation{},
|
||||
|
@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
||||
// GetAllRoutes returns the list of routes for the account
|
||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
// CreateRoute handles route creation request
|
||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
peerGroupIds = *req.PeerGroups
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerId); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
|
||||
)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
// UpdateRoute handles update to a route identified by a given ID
|
||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
peerID = *req.Peer
|
||||
}
|
||||
|
||||
// do not allow non Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID(routeID),
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
newRoute.PeerGroups = *req.PeerGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// DeleteRoute handles route deletion request
|
||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// GetRoute handles a route Get request identified by ID
|
||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||
return
|
||||
|
@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
|
||||
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
||||
}
|
||||
if peerID != "" {
|
||||
if peerID == nonLinuxExistingPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
return &route.Route{
|
||||
ID: existingRouteID,
|
||||
NetID: netID,
|
||||
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
|
||||
if r.Peer == notFoundPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||
}
|
||||
|
||||
if r.Peer == nonLinuxExistingPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
//return testingAccount, testingAccount.Users["test_user"], nil
|
||||
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
if req.Ephemeral != nil {
|
||||
ephemeral = *req.Ephemeral
|
||||
}
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
) *SetupKeysHandler {
|
||||
return &SetupKeysHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
SetupKeys: map[string]*server.SetupKey{
|
||||
defaultKey.Key: defaultKey,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"},
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool,
|
||||
|
@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
existingUser, ok := account.Users[userID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
||||
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
||||
Id: userID,
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Blocked: req.IsBlocked,
|
||||
@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
|
@ -23,10 +23,11 @@ import (
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
||||
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
@ -48,7 +49,7 @@ type MockAccountManager struct {
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
@ -79,7 +80,7 @@ type MockAccountManager struct {
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
@ -105,6 +106,9 @@ type MockAccountManager struct {
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
ctx context.Context, userId, accountId, domain string,
|
||||
) (*server.Account, error) {
|
||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
||||
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
|
||||
if am.GetAccountIDByUserOrAccountIdFunc != nil {
|
||||
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
return "", status.Errorf(
|
||||
codes.Unimplemented,
|
||||
"method GetAccountByUserOrAccountID is not implemented",
|
||||
"method GetAccountIDByUserOrAccountID is not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
||||
}
|
||||
|
||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||
if am.SavePolicyFunc != nil {
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
}
|
||||
@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
||||
error,
|
||||
) {
|
||||
if am.GetAccountFromTokenFunc != nil {
|
||||
return am.GetAccountFromTokenFunc(ctx, claims)
|
||||
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if am.GetAccountIDFromTokenFunc != nil {
|
||||
return am.GetAccountIDFromTokenFunc(ctx, claims)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
|
||||
}
|
||||
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountByID mocks GetAccountByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
if am.GetAccountByIDFunc != nil {
|
||||
return am.GetAccountByIDFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
|
||||
}
|
||||
|
||||
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
|
||||
if am.GetUserByIDFunc != nil {
|
||||
return am.GetUserByIDFunc(ctx, id)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
if am.GetAccountSettingsFunc != nil {
|
||||
return am.GetAccountSettingsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
|
||||
if am.GetAccountFunc != nil {
|
||||
return am.GetAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||
}
|
||||
|
@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
}
|
||||
|
||||
nsGroup, found := account.NameServerGroups[nsGroupID]
|
||||
if found {
|
||||
return nsGroup.Copy(), nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
}
|
||||
|
||||
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
||||
for _, item := range account.NameServerGroups {
|
||||
nsGroups = append(nsGroups, item.Copy())
|
||||
}
|
||||
|
||||
return nsGroups, nil
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||
|
@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
policy.Enabled = false
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
return policy, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
exists := am.savePolicy(account, policy)
|
||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
action := activity.PolicyAdded
|
||||
if exists {
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
|
||||
}
|
||||
|
||||
return account.Policies, nil
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
|
||||
for i, p := range account.Policies {
|
||||
if p.ID == policy.ID {
|
||||
account.Policies[i] = policy
|
||||
exists = true
|
||||
break
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
policyToSave.Rules[index] = rule
|
||||
}
|
||||
|
||||
if policyToSave.SourcePostureChecks != nil {
|
||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||
}
|
||||
if !exists {
|
||||
account.Policies = append(account.Policies, policy)
|
||||
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
}
|
||||
return
|
||||
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||
// that are valid within the provided account.
|
||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||
result := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if id == postureCheck.ID {
|
||||
result = append(result, id)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||
result := make([]string, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if _, exists := account.Groups[groupID]; exists {
|
||||
result = append(result, groupID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -15,30 +15,16 @@ const (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return account.PostureChecks, nil
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||
|
@ -17,29 +17,16 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
wantedRoute, found := account.Routes[routeID]
|
||||
if found {
|
||||
return wantedRoute, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0, len(account.Routes))
|
||||
for _, item := range account.Routes {
|
||||
routes = append(routes, item)
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
|
@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||
|
@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
||||
}
|
||||
|
||||
keys := make([]*SetupKey, 0, len(account.SetupKeys))
|
||||
for _, key := range account.SetupKeys {
|
||||
keys := make([]*SetupKey, 0, len(setupKeys))
|
||||
for _, key := range setupKeys {
|
||||
var k *SetupKey
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
k = key.HiddenCopy(999)
|
||||
} else {
|
||||
k = key.Copy()
|
||||
@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
||||
}
|
||||
|
||||
var foundKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyID {
|
||||
foundKey = key.Copy()
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundKey == nil {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||
if foundKey.UpdatedAt.IsZero() {
|
||||
foundKey.UpdatedAt = foundKey.CreatedAt
|
||||
if setupKey.UpdatedAt.IsZero() {
|
||||
setupKey.UpdatedAt = setupKey.CreatedAt
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
foundKey = foundKey.HiddenCopy(999)
|
||||
if !user.IsAdminOrServiceUser() {
|
||||
setupKey = setupKey.HiddenCopy(999)
|
||||
}
|
||||
|
||||
return foundKey, nil
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
|
@ -36,6 +36,7 @@ const (
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
|
||||
@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: rework to not call GetAccount
|
||||
return s.GetAccount(ctx, account.Id)
|
||||
return s.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory,
|
||||
).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&user, idQueryCondition, userID)
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
var user User
|
||||
var accountID string
|
||||
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
func (s *SqlStore) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var accountID string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
return accountID != "", nil
|
||||
}
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||
}
|
||||
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||
var record T
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@ -39,53 +40,81 @@ const (
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
|
||||
// Close should close the store persisting all unsaved data.
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
}
|
||||
|
||||
|
@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool {
|
||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||
}
|
||||
|
||||
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
|
||||
func (u *User) IsAdminOrServiceUser() bool {
|
||||
return u.HasAdminPower() || u.IsServiceUser
|
||||
}
|
||||
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
||||
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get an account from store %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
|
||||
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
|
||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||
|
||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||
}
|
||||
|
||||
if newLogin {
|
||||
meta := map[string]any{"timestamp": claims.LastLogin}
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, ok := account.Users[targetUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
executingUser, ok := account.Users[initiatorUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||
for _, pat := range targetUser.PATsG {
|
||||
if pat.ID == tokenID {
|
||||
return pat.Copy(), nil
|
||||
}
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, ok := account.Users[targetUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
executingUser, ok := account.Users[initiatorUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
||||
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
var pats []*PersonalAccessToken
|
||||
for _, pat := range targetUser.PATs {
|
||||
pats = append(pats, pat)
|
||||
pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
|
||||
for _, pat := range targetUser.PATsG {
|
||||
pats = append(pats, pat.Copy())
|
||||
}
|
||||
|
||||
return pats, nil
|
||||
|
@ -200,6 +200,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
@ -232,6 +233,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), accID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
|
Loading…
Reference in New Issue
Block a user