mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-01 11:46:39 +02:00
[management] Remove redundant get account calls in GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor jwt groups extractor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to get account when necessary Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * revert handles change Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove GetUserByID from account manager Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims to return account id Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to use GetAccountIDFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove locks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByName from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByID from store and refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor retrieval of policy and posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor user permissions and retrieves PAT Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor route, setupkey, nameserver and dns to get record(s) from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix add missing policy source posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add store lock Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add get account Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
4ebf6e1c4c
commit
acb73bd64a
@ -20,11 +20,6 @@ import (
|
|||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
@ -41,6 +36,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
|
|||||||
|
|
||||||
type AccountManager interface {
|
type AccountManager interface {
|
||||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
||||||
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||||
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
||||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
@ -75,12 +75,14 @@ type AccountManager interface {
|
|||||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||||
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
|
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@ -107,7 +109,7 @@ type AccountManager interface {
|
|||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
|
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@ -145,6 +147,7 @@ type AccountManager interface {
|
|||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@ -268,6 +271,11 @@ type AccountNetwork struct {
|
|||||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountDNSSettings used in gorm to only load dns settings and not whole account
|
||||||
|
type AccountDNSSettings struct {
|
||||||
|
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermissions struct {
|
||||||
DashboardView string `json:"dashboard_view"`
|
DashboardView string `json:"dashboard_view"`
|
||||||
}
|
}
|
||||||
@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
||||||
// userID doesn't have an account associated with it, one account is created
|
// If an accountID is provided, it checks if the account exists and returns it.
|
||||||
// domain is used to create a new account if no account is found
|
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
||||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
|
// If the user doesn't have an account, it creates one using the provided domain.
|
||||||
|
// Returns the account ID or an error if none is found or created.
|
||||||
|
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||||
if accountID != "" {
|
if accountID != "" {
|
||||||
return am.Store.GetAccount(ctx, accountID)
|
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||||
} else if userID != "" {
|
|
||||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
return "", err
|
||||||
}
|
}
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
if !exists {
|
||||||
if err != nil {
|
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return account, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
if userID != "" {
|
||||||
|
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(i idp.Manager) bool {
|
func isNil(i idp.Manager) bool {
|
||||||
@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||||
// only possible with the enabled IdP manager
|
// only possible with the enabled IdP manager
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
|||||||
return am.Store.SaveAccount(ctx, account)
|
return am.Store.SaveAccount(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccount returns an account associated with this account ID.
|
||||||
|
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||||
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||||
if len(token) != PATLength {
|
if len(token) != PATLength {
|
||||||
@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
|
|||||||
return account, user, pat, nil
|
return account, user, pat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken returns an account associated with this token
|
// GetAccountByID returns an account associated with this account ID.
|
||||||
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||||
|
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, nil, fmt.Errorf("user ID is empty")
|
return "", "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
|
||||||
alreadyUnlocked := false
|
|
||||||
defer func() {
|
|
||||||
if !alreadyUnlocked {
|
|
||||||
unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user := account.Users[claims.UserId]
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
if user == nil {
|
if err != nil {
|
||||||
// this is not really possible because we got an account by user ID
|
// this is not really possible because we got an account by user ID
|
||||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.JWTGroupsEnabled {
|
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||||
if account.Settings.JWTGroupsClaimName == "" {
|
return "", "", err
|
||||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
|
||||||
return account, user, nil
|
|
||||||
}
|
|
||||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if slice, ok := claim.([]interface{}); ok {
|
|
||||||
var groupsNames []string
|
|
||||||
for _, item := range slice {
|
|
||||||
if g, ok := item.(string); ok {
|
|
||||||
groupsNames = append(groupsNames, g)
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
|
||||||
copy(oldGroups, user.AutoGroups)
|
|
||||||
// if groups were added or modified, save the account
|
|
||||||
if account.SetJWTGroups(claims.UserId, groupsNames) {
|
|
||||||
if account.Settings.GroupsPropagationEnabled {
|
|
||||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
|
||||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
|
||||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
unlock()
|
|
||||||
alreadyUnlocked = true
|
|
||||||
for _, g := range addNewGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, g := range removeOldGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, user, nil
|
return accountID, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
|
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings == nil || !settings.JWTGroupsEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.JWTGroupsClaimName == "" {
|
||||||
|
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove GetAccount after refactoring account peer's update
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
|
oldGroups := make([]string, len(user.AutoGroups))
|
||||||
|
copy(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
// Update the account if group membership changes
|
||||||
|
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||||
|
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||||
|
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||||
|
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||||
|
account.Network.IncSerial()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate changes to peers if group propagation is enabled
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range addNewGroups {
|
||||||
|
if group := account.GetGroup(g); group != nil {
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||||
|
map[string]any{
|
||||||
|
"group": group.Name,
|
||||||
|
"group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser,
|
||||||
|
"user_name": user.ServiceUserName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range removeOldGroups {
|
||||||
|
if group := account.GetGroup(g); group != nil {
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||||
|
map[string]any{
|
||||||
|
"group": group.Name,
|
||||||
|
"group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser,
|
||||||
|
"user_name": user.ServiceUserName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||||
// if domain is of the PrivateCategory category, it will evaluate
|
// if domain is of the PrivateCategory category, it will evaluate
|
||||||
// if account is new, existing or if there is another account with the same domain
|
// if account is new, existing or if there is another account with the same domain
|
||||||
//
|
//
|
||||||
@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
|
||||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
if userAccountID != claims.AccountId {
|
||||||
|
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
}
|
}
|
||||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
|
||||||
return accountFromID, nil
|
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||||
|
return userAccountID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if NotFound we are good to continue, otherwise return error
|
// if NotFound we are good to continue, otherwise return error
|
||||||
e, ok := status.FromError(err)
|
e, ok := status.FromError(err)
|
||||||
if !ok || e.Type() != status.NotFound {
|
if !ok || e.Type() != status.NotFound {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||||
// and peers that shouldn't be lost.
|
// and peers that shouldn't be lost.
|
||||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||||
|
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
return "", err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
|
return account.Id, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
if domainAccount != nil {
|
var domainAccount *Account
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
if domainAccountID != "" {
|
||||||
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
|
||||||
|
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||||
// group propagation and set the list of groups with access permissions.
|
// group propagation and set the list of groups with access permissions.
|
||||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensures JWT group synchronization to the management is enabled before,
|
// Ensures JWT group synchronization to the management is enabled before,
|
||||||
// filtering access based on the allowed groups.
|
// filtering access based on the allowed groups.
|
||||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
if settings != nil && settings.JWTGroupsEnabled {
|
||||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||||
userJWTGroups := make([]string, 0)
|
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if claimGroups, ok := claim.([]interface{}); ok {
|
|
||||||
for _, g := range claimGroups {
|
|
||||||
if group, ok := g.(string); ok {
|
|
||||||
userJWTGroups = append(userJWTGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||||
@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
|
|||||||
return newLabel, nil
|
return newLabel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exist
|
// addAllGroup to account object if it doesn't exist
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||||
|
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||||
|
userJWTGroups := make([]string, 0)
|
||||||
|
|
||||||
|
if claim, ok := claims.Raw[claimName]; ok {
|
||||||
|
if claimGroups, ok := claim.([]interface{}); ok {
|
||||||
|
for _, g := range claimGroups {
|
||||||
|
if group, ok := g.(string); ok {
|
||||||
|
userJWTGroups = append(userJWTGroups, group)
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userJWTGroups
|
||||||
|
}
|
||||||
|
|
||||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||||
for _, userGroup := range userGroups {
|
for _, userGroup := range userGroups {
|
||||||
|
@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
assert.Equal(t, account.Id, ev.TargetID)
|
assert.Equal(t, account.Id, ev.TargetID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||||
type initUserParams jwtclaims.AuthorizationClaims
|
type initUserParams jwtclaims.AuthorizationClaims
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
|
|
||||||
|
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
if testCase.inputUpdateAttrs {
|
if testCase.inputUpdateAttrs {
|
||||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||||
require.NoError(t, err, "update init user failed")
|
require.NoError(t, err, "update init user failed")
|
||||||
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
testCase.inputClaims.AccountId = initAccount.Id
|
testCase.inputClaims.AccountId = initAccount.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
|
||||||
require.NoError(t, err, "support function failed")
|
require.NoError(t, err, "support function failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||||
|
|
||||||
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID := initAccount.Id
|
accountID := initAccount.Id
|
||||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
// as initAccount was created without account id we have to take the id after account initialization
|
// as initAccount was created without account id we have to take the id after account initialization
|
||||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
|
||||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
initAccount = acc
|
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
claims := jwtclaims.AuthorizationClaims{
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||||
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||||
|
|
||||||
groupsByNames := map[string]*group.Group{}
|
groupsByNames := map[string]*group.Group{}
|
||||||
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
|
|
||||||
userId := "test_user"
|
userId := "test_user"
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if account == nil {
|
if accountID == "" {
|
||||||
t.Fatalf("expected to create an account for a user %s", userId)
|
t.Fatalf("expected to create an account for a user %s", userId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected an error when user and account IDs are empty")
|
t.Errorf("expected an error when user and account IDs are empty")
|
||||||
}
|
}
|
||||||
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||||
t.Errorf("delete default rule: %v", err)
|
t.Errorf("delete default rule: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
assert.NotNil(t, account.Settings)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
require.NoError(t, err, "unable to get account settings")
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
|
||||||
|
assert.NotNil(t, settings)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
|
||||||
|
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
require.NoError(t, err, "unable to get account by ID")
|
require.NoError(t, err, "unable to get account by ID")
|
||||||
|
|
||||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
require.NoError(t, err, "unable to get account settings")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Second,
|
PeerLoginExpiration: time.Second,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
|
@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
|
|||||||
|
|
||||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||||
}
|
}
|
||||||
dnsSettings := account.DNSSettings.Copy()
|
|
||||||
return &dnsSettings, nil
|
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
|
@ -10,14 +10,15 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/netbirdio/netbird/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.Users[userID].Copy(), nil
|
user := account.Users[userID].Copy()
|
||||||
|
pat := make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
|
for _, token := range user.PATs {
|
||||||
|
if token != nil {
|
||||||
|
pat = append(pat, *token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user.PATsG = pat
|
||||||
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
|
|||||||
return FileStoreEngine
|
return FileStoreEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
||||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
||||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
|
||||||
|
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Domain, account.DomainCategory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
|
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
|
||||||
|
_, exists := s.Accounts[id]
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
||||||
|
}
|
||||||
|
@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
|
|||||||
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroup object of the peers
|
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||||
|
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroup returns a specific group by groupID in an account
|
||||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if ok {
|
|
||||||
return group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
return am.Store.GetAccountGroups(ctx, accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
|
||||||
for _, item := range account.Groups {
|
|
||||||
groups = append(groups, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
matchingGroups := make([]*nbgroup.Group, 0)
|
|
||||||
for _, group := range account.Groups {
|
|
||||||
if group.Name == groupName {
|
|
||||||
matchingGroups = append(matchingGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matchingGroups) == 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
|
|
||||||
}
|
|
||||||
|
|
||||||
maxPeers := -1
|
|
||||||
var groupWithMostPeers *nbgroup.Group
|
|
||||||
for i, group := range matchingGroups {
|
|
||||||
if len(group.Peers) > maxPeers {
|
|
||||||
maxPeers = len(group.Peers)
|
|
||||||
groupWithMostPeers = matchingGroups[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return groupWithMostPeers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if allGroup.ID == groupID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
}
|
}
|
||||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(account)
|
resp := toAccountResponse(accountID, settings)
|
||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(updatedAccount)
|
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||||
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodDelete {
|
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
targetAccountID := vars["accountId"]
|
targetAccountID := vars["accountId"]
|
||||||
@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(account *server.Account) *api.Account {
|
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
||||||
jwtAllowGroups := account.Settings.JWTAllowGroups
|
jwtAllowGroups := settings.JWTAllowGroups
|
||||||
if jwtAllowGroups == nil {
|
if jwtAllowGroups == nil {
|
||||||
jwtAllowGroups = []string{}
|
jwtAllowGroups = []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := api.AccountSettings{
|
apiSettings := api.AccountSettings{
|
||||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
|
||||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
|
||||||
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
|
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
|
||||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
|
||||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
|
||||||
JwtAllowGroups: &jwtAllowGroups,
|
JwtAllowGroups: &jwtAllowGroups,
|
||||||
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
|
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
Id: account.Id,
|
Id: accountID,
|
||||||
Settings: settings,
|
Settings: apiSettings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,8 +23,11 @@ import (
|
|||||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||||
return &AccountsHandler{
|
return &AccountsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return account, admin, nil
|
return account.Id, admin.Id, nil
|
||||||
|
},
|
||||||
|
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||||
|
return account.Settings, nil
|
||||||
},
|
},
|
||||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
|
@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetDNSSettings returns the DNS settings for the account
|
// GetDNSSettings returns the DNS settings for the account
|
||||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
|||||||
// UpdateDNSSettings handles update to DNS settings of an account
|
// UpdateDNSSettings handles update to DNS settings of an account
|
||||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
|||||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
|
|||||||
}
|
}
|
||||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
|||||||
// GetAllEvents list of the given account
|
// GetAllEvents list of the given account
|
||||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
events[i] = toEventResponse(e)
|
events[i] = toEventResponse(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
|
||||||
return &EventsHandler{
|
return &EventsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||||
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
|||||||
}
|
}
|
||||||
return []*activity.Event{}, nil
|
return []*activity.Event{}, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: claims.AccountId,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
user.Id: user,
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
return make([]*server.UserInfo, 0), nil
|
return make([]*server.UserInfo, 0), nil
|
||||||
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
events := generateEvents(accountID, adminUser.Id)
|
events := generateEvents(accountID, adminUser.Id)
|
||||||
handler := initEventsTestData(accountID, adminUser, events...)
|
handler := initEventsTestData(accountID, events...)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -11,9 +11,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
|||||||
|
|
||||||
return &GeolocationsHandler{
|
return &GeolocationsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return claims.AccountId, claims.UserId, nil
|
||||||
return &server.Account{
|
},
|
||||||
Id: claims.AccountId,
|
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||||
Users: map[string]*server.User{
|
return server.NewAdminUser(id), nil
|
||||||
"test_user": user,
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
geolocationManager: geo,
|
geolocationManager: geo,
|
||||||
|
@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
|||||||
|
|
||||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||||
claims := l.claimsExtractor.FromRequestContext(r)
|
claims := l.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := l.accountManager.GetUserByID(r.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
|||||||
// GetAllGroups list for the account
|
// GetAllGroups list for the account
|
||||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
groupsResponse := make([]*api.Group, 0, len(groups))
|
groupsResponse := make([]*api.Group, 0, len(groups))
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||||
@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateGroup handles update to a group identified by a given ID
|
// UpdateGroup handles update to a group identified by a given ID
|
||||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
eg, ok := account.Groups[groupID]
|
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if !ok {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
ID: groupID,
|
ID: groupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Issued: eg.Issued,
|
Issued: existingGroup.Issued,
|
||||||
IntegrationReference: eg.IntegrationReference,
|
IntegrationReference: existingGroup.IntegrationReference,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroup handles group creation request
|
// CreateGroup handles group creation request
|
||||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
Issued: nbgroup.GroupIssuedAPI,
|
Issued: nbgroup.GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup handles group deletion request
|
// DeleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
groupID := mux.Vars(r)["groupId"]
|
groupID := mux.Vars(r)["groupId"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, ok := err.(*server.GroupLinkError)
|
_, ok := err.(*server.GroupLinkError)
|
||||||
if ok {
|
if ok {
|
||||||
@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetGroup returns a group
|
// GetGroup returns a group
|
||||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID := mux.Vars(r)["groupId"]
|
||||||
|
if len(groupID) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
case http.MethodGet:
|
if err != nil {
|
||||||
groupID := mux.Vars(r)["groupId"]
|
util.WriteError(r.Context(), err, w)
|
||||||
if len(groupID) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
|
||||||
|
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
peersMap[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.PeerMinimum)
|
cache := make(map[string]api.PeerMinimum)
|
||||||
gr := api.Group{
|
gr := api.Group{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
|||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
_, ok := cache[pid]
|
_, ok := cache[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
peer, ok := account.Peers[pid]
|
peer, ok := peersMap[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
|
|||||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
||||||
return &GroupsHandler{
|
return &GroupsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||||
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||||
if groupID != "idofthegroup" {
|
groups := map[string]*nbgroup.Group{
|
||||||
|
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
||||||
|
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
||||||
|
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range initGroups {
|
||||||
|
groups[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "not found")
|
return nil, status.Errorf(status.NotFound, "not found")
|
||||||
}
|
}
|
||||||
if groupID == "id-jwt-group" {
|
|
||||||
return &nbgroup.Group{
|
return group, nil
|
||||||
ID: "id-jwt-group",
|
|
||||||
Name: "Default Group",
|
|
||||||
Issued: nbgroup.GroupIssuedJWT,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &nbgroup.Group{
|
|
||||||
ID: "idofthegroup",
|
|
||||||
Name: "Group",
|
|
||||||
Issued: nbgroup.GroupIssuedAPI,
|
|
||||||
}, nil
|
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: claims.AccountId,
|
},
|
||||||
Domain: "hotmail.com",
|
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
|
||||||
Peers: TestPeers,
|
if groupName == "All" {
|
||||||
Users: map[string]*server.User{
|
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
|
||||||
user.Id: user,
|
}
|
||||||
},
|
|
||||||
Groups: map[string]*nbgroup.Group{
|
return nil, fmt.Errorf("unknown group name")
|
||||||
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
},
|
||||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||||
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
return maps.Values(TestPeers), nil
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||||
if groupID == "linked-grp" {
|
if groupID == "linked-grp" {
|
||||||
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
Name: "Group",
|
Name: "Group",
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData(group)
|
||||||
p := initGroupTestData(adminUser, group)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData()
|
||||||
p := initGroupTestData(adminUser)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData()
|
||||||
p := initGroupTestData(adminUser)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetAllNameservers returns the list of nameserver groups for the account
|
// GetAllNameservers returns the list of nameserver groups for the account
|
||||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
|||||||
// CreateNameserverGroup handles nameserver group creation request
|
// CreateNameserverGroup handles nameserver group creation request
|
||||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// DeleteNameserverGroup handles nameserver group deletion request
|
// DeleteNameserverGroup handles nameserver group deletion request
|
||||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -18,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
@ -29,14 +28,6 @@ const (
|
|||||||
testNSGroupAccountID = "test_id"
|
testNSGroupAccountID = "test_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testingNSAccount = &server.Account{
|
|
||||||
Id: testNSGroupAccountID,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||||
ID: existingNSGroupID,
|
ID: existingNSGroupID,
|
||||||
Name: "super",
|
Name: "super",
|
||||||
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
|
|||||||
}
|
}
|
||||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
|||||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
|
|||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testAccount, testAccount.Users[existingUserID], nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
|
|||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: testNSGroupAccountID,
|
AccountId: existingAccountID,
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
|||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PeerRequest{}
|
req := &api.PeerRequest{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
return
|
return
|
||||||
@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
|
|||||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodDelete:
|
case http.MethodDelete:
|
||||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||||
return
|
return
|
||||||
case http.MethodPut:
|
case http.MethodGet, http.MethodPut:
|
||||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
return
|
if err != nil {
|
||||||
case http.MethodGet:
|
util.WriteError(r.Context(), err, w)
|
||||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||||
|
} else {
|
||||||
|
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||||
@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// GetAllPeers returns a list of all peers associated with a provided account
|
// GetAllPeers returns a list of all peers associated with a provided account
|
||||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range account.Peers {
|
||||||
peerToReturn, err := h.checkPeerStatus(peer)
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
|||||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := account.FindUser(userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// If the user is regular user and does not own the peer
|
// If the user is regular user and does not own the peer
|
||||||
// with the given peerID return an empty list
|
// with the given peerID return an empty list
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||||
|
@ -13,16 +13,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
GetDNSDomainFunc: func() string {
|
GetDNSDomainFunc: func() string {
|
||||||
return "netbird.selfhosted"
|
return "netbird.selfhosted"
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
|
return claims.AccountId, claims.UserId, nil
|
||||||
|
},
|
||||||
|
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
peersMap := make(map[string]*nbpeer.Peer)
|
peersMap := make(map[string]*nbpeer.Peer)
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
peersMap[peer.ID] = peer.Copy()
|
peersMap[peer.ID] = peer.Copy()
|
||||||
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
|
|
||||||
policy := &server.Policy{
|
policy := &server.Policy{
|
||||||
ID: "policy",
|
ID: "policy",
|
||||||
AccountID: claims.AccountId,
|
AccountID: accountID,
|
||||||
Name: "policy",
|
Name: "policy",
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*server.PolicyRule{
|
Rules: []*server.PolicyRule{
|
||||||
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
srvUser.IsServiceUser = true
|
srvUser.IsServiceUser = true
|
||||||
|
|
||||||
account := &server.Account{
|
account := &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Peers: peersMap,
|
Peers: peersMap,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
Groups: map[string]*nbgroup.Group{
|
Groups: map[string]*nbgroup.Group{
|
||||||
"group1": {
|
"group1": {
|
||||||
ID: "group1",
|
ID: "group1",
|
||||||
AccountID: claims.AccountId,
|
AccountID: accountID,
|
||||||
Name: "group1",
|
Name: "group1",
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
Peers: maps.Keys(peersMap),
|
Peers: maps.Keys(peersMap),
|
||||||
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, account.Users[claims.UserId], nil
|
return account, nil
|
||||||
},
|
},
|
||||||
HasConnectedChannelFunc: func(peerID string) bool {
|
HasConnectedChannelFunc: func(peerID string) bool {
|
||||||
statuses := make(map[string]struct{})
|
statuses := make(map[string]struct{})
|
||||||
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
// hardcode this check for now as we only have two peers in this suite
|
// hardcode this check for now as we only have two peers in this suite
|
||||||
assert.Equal(t, len(respBody), 2)
|
assert.Equal(t, len(respBody), 2)
|
||||||
assert.Equal(t, respBody[1].Connected, false)
|
|
||||||
|
|
||||||
got = respBody[0]
|
for _, peer := range respBody {
|
||||||
|
if peer.Id == testPeerID {
|
||||||
|
got = peer
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, peer.Connected, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
got = &api.Peer{}
|
got = &api.Peer{}
|
||||||
err = json.Unmarshal(content, got)
|
err = json.Unmarshal(content, got)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllPolicies list for the account
|
// GetAllPolicies list for the account
|
||||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policies := []*api.Policy{}
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
for _, policy := range accountPolicies {
|
if err != nil {
|
||||||
resp := toPolicyResponse(account, policy)
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := make([]*api.Policy, 0, len(listPolicies))
|
||||||
|
for _, policy := range listPolicies {
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdatePolicy handles update to a policy identified by a given ID
|
// UpdatePolicy handles update to a policy identified by a given ID
|
||||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policyIdx := -1
|
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
for i, policy := range account.Policies {
|
|
||||||
if policy.ID == policyID {
|
|
||||||
policyIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if policyIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, policyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePolicy handles policy creation request
|
|
||||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, "")
|
h.savePolicy(w, r, accountID, userID, policyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePolicy handles policy creation request
|
||||||
|
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.savePolicy(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// savePolicy handles policy creation and update
|
// savePolicy handles policy creation and update
|
||||||
func (h *Policies) savePolicy(
|
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
policyID string,
|
|
||||||
) {
|
|
||||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isUpdate := policyID != ""
|
||||||
|
|
||||||
if policyID == "" {
|
if policyID == "" {
|
||||||
policyID = xid.New().String()
|
policyID = xid.New().String()
|
||||||
}
|
}
|
||||||
@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
|
|||||||
pr := server.PolicyRule{
|
pr := server.PolicyRule{
|
||||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
Destinations: rule.Destinations,
|
||||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
Sources: rule.Sources,
|
||||||
Bidirectional: rule.Bidirectional,
|
Bidirectional: rule.Bidirectional,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.SourcePostureChecks != nil {
|
if req.SourcePostureChecks != nil {
|
||||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toPolicyResponse(account, &policy)
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, &policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
|
|||||||
// DeletePolicy handles policy deletion request
|
// DeletePolicy handles policy deletion request
|
||||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["policyId"]
|
policyID := vars["policyId"]
|
||||||
@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetPolicy handles a group Get request identified by ID
|
// GetPolicy handles a group Get request identified by ID
|
||||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
vars := mux.Vars(r)
|
||||||
case http.MethodGet:
|
policyID := vars["policyId"]
|
||||||
vars := mux.Vars(r)
|
if len(policyID) == 0 {
|
||||||
policyID := vars["policyId"]
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
if len(policyID) == 0 {
|
return
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := toPolicyResponse(account, policy)
|
|
||||||
if len(resp.Rules) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, resp)
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
|
if len(resp.Rules) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
|
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
|
||||||
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.GroupMinimum)
|
cache := make(map[string]api.GroupMinimum)
|
||||||
ap := &api.Policy{
|
ap := &api.Policy{
|
||||||
Id: &policy.ID,
|
Id: &policy.ID,
|
||||||
@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
||||||
Action: api.PolicyRuleAction(r.Action),
|
Action: api.PolicyRuleAction(r.Action),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Ports) != 0 {
|
if len(r.Ports) != 0 {
|
||||||
portsCopy := r.Ports
|
portsCopy := r.Ports
|
||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Sources {
|
for _, gid := range r.Sources {
|
||||||
_, ok := cache[gid]
|
_, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
cache[gid] = minimum
|
cache[gid] = minimum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Destinations {
|
for _, gid := range r.Destinations {
|
||||||
cachedMinimum, ok := cache[gid]
|
cachedMinimum, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
}
|
}
|
||||||
return ap
|
return ap
|
||||||
}
|
}
|
||||||
|
|
||||||
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
|
|
||||||
result := make([]string, 0, len(gm))
|
|
||||||
for _, g := range gm {
|
|
||||||
if _, ok := account.Groups[g]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result = append(result, g)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
|
|
||||||
result := make([]string, 0, len(postureChecksIds))
|
|
||||||
for _, id := range postureChecksIds {
|
|
||||||
for _, postureCheck := range account.PostureChecks {
|
|
||||||
if id == postureCheck.ID {
|
|
||||||
result = append(result, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
}
|
}
|
||||||
return policy, nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||||
if !strings.HasPrefix(policy.ID, "id-") {
|
if !strings.HasPrefix(policy.ID, "id-") {
|
||||||
policy.ID = "id-was-set"
|
policy.ID = "id-was-set"
|
||||||
policy.Rules[0].ID = "id-was-set"
|
policy.Rules[0].ID = "id-was-set"
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||||
|
},
|
||||||
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
|
return claims.AccountId, claims.UserId, nil
|
||||||
|
},
|
||||||
|
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
|
user := server.NewAdminUser(userID)
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Policies: []*server.Policy{
|
Policies: []*server.Policy{
|
||||||
{ID: "id-existed"},
|
{ID: "id-existed"},
|
||||||
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": user,
|
"test_user": user,
|
||||||
},
|
},
|
||||||
}, user, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
|||||||
// GetAllPostureChecks list for the account
|
// GetAllPostureChecks list for the account
|
||||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := []*api.PostureCheck{}
|
postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
|
||||||
for _, postureCheck := range accountPostureChecks {
|
for _, postureCheck := range listPostureChecks {
|
||||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
|||||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksIdx := -1
|
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||||
for i, postureCheck := range account.PostureChecks {
|
if err != nil {
|
||||||
if postureCheck.ID == postureChecksID {
|
util.WriteError(r.Context(), err, w)
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, postureChecksID)
|
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePostureCheck handles posture check creation request
|
// CreatePostureCheck handles posture check creation request
|
||||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, "")
|
p.savePostureChecks(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureCheck handles a posture check Get request identified by ID
|
// GetPostureCheck handles a posture check Get request identified by ID
|
||||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
// DeletePostureCheck handles posture check deletion request
|
// DeletePostureCheck handles posture check deletion request
|
||||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// savePostureChecks handles posture checks create and update
|
// savePostureChecks handles posture checks create and update
|
||||||
func (p *PostureChecksHandler) savePostureChecks(
|
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
postureChecksID string,
|
|
||||||
) {
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
req api.PostureCheckUpdate
|
req api.PostureCheckUpdate
|
||||||
@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
}
|
}
|
||||||
return accountPostureChecks, nil
|
return accountPostureChecks, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return claims.AccountId, claims.UserId, nil
|
||||||
return &server.Account{
|
|
||||||
Id: claims.AccountId,
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
"test_user": user,
|
|
||||||
},
|
|
||||||
PostureChecks: postureChecks,
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
geolocationManager: &geolocation.Geolocation{},
|
geolocationManager: &geolocation.Geolocation{},
|
||||||
|
@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
|||||||
// GetAllRoutes returns the list of routes for the account
|
// GetAllRoutes returns the list of routes for the account
|
||||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateRoute handles route creation request
|
// CreateRoute handles route creation request
|
||||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerGroupIds = *req.PeerGroups
|
peerGroupIds = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not allow non-Linux peers
|
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||||
if peer := account.GetPeer(peerId); peer != nil {
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
|
||||||
if peer.Meta.GoOS != "linux" {
|
)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
// UpdateRoute handles update to a route identified by a given ID
|
// UpdateRoute handles update to a route identified by a given ID
|
||||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerID = *req.Peer
|
peerID = *req.Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// do not allow non Linux peers
|
|
||||||
if peer := account.GetPeer(peerID); peer != nil {
|
|
||||||
if peer.Meta.GoOS != "linux" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
newRoute := &route.Route{
|
||||||
ID: route.ID(routeID),
|
ID: route.ID(routeID),
|
||||||
NetID: route.NetID(req.NetworkId),
|
NetID: route.NetID(req.NetworkId),
|
||||||
@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
newRoute.PeerGroups = *req.PeerGroups
|
newRoute.PeerGroups = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteRoute handles route deletion request
|
// DeleteRoute handles route deletion request
|
||||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetRoute handles a route Get request identified by ID
|
// GetRoute handles a route Get request identified by ID
|
||||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
|
@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
||||||
}
|
}
|
||||||
|
if peerID != "" {
|
||||||
|
if peerID == nonLinuxExistingPeerID {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &route.Route{
|
return &route.Route{
|
||||||
ID: existingRouteID,
|
ID: existingRouteID,
|
||||||
NetID: netID,
|
NetID: netID,
|
||||||
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
if r.Peer == notFoundPeerID {
|
if r.Peer == notFoundPeerID {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Peer == nonLinuxExistingPeerID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||||
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingAccount, testingAccount.Users["test_user"], nil
|
//return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
|
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
|||||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
if req.Ephemeral != nil {
|
if req.Ephemeral != nil {
|
||||||
ephemeral = *req.Ephemeral
|
ephemeral = *req.Ephemeral
|
||||||
}
|
}
|
||||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
newKey.Name = req.Name
|
newKey.Name = req.Name
|
||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
) *SetupKeysHandler {
|
) *SetupKeysHandler {
|
||||||
return &SetupKeysHandler{
|
return &SetupKeysHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: testAccountID,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
user.Id: user,
|
|
||||||
},
|
|
||||||
SetupKeys: map[string]*server.SetupKey{
|
|
||||||
defaultKey.Key: defaultKey,
|
|
||||||
},
|
|
||||||
Groups: map[string]*nbgroup.Group{
|
|
||||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
|
||||||
"id-all": {ID: "id-all", Name: "All"},
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||||
_ int, _ string, ephemeral bool,
|
_ int, _ string, ephemeral bool,
|
||||||
|
@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(targetUserID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUser, ok := account.Users[userID]
|
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||||
Id: userID,
|
Id: targetUserID,
|
||||||
Role: userRole,
|
Role: userRole,
|
||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
Blocked: req.IsBlocked,
|
Blocked: req.IsBlocked,
|
||||||
@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
name = *req.Name
|
name = *req.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
Role: req.Role,
|
Role: req.Role,
|
||||||
@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
|
|||||||
func initUsersTestData() *UsersHandler {
|
func initUsersTestData() *UsersHandler {
|
||||||
return &UsersHandler{
|
return &UsersHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
return usersTestAccount.Id, claims.UserId, nil
|
||||||
|
},
|
||||||
|
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||||
|
return usersTestAccount.Users[id], nil
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
users := make([]*server.UserInfo, 0)
|
users := make([]*server.UserInfo, 0)
|
||||||
|
@ -23,10 +23,11 @@ import (
|
|||||||
|
|
||||||
type MockAccountManager struct {
|
type MockAccountManager struct {
|
||||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||||
|
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
|
||||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@ -48,7 +49,7 @@ type MockAccountManager struct {
|
|||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
@ -79,7 +80,7 @@ type MockAccountManager struct {
|
|||||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||||
GetDNSDomainFunc func() string
|
GetDNSDomainFunc func() string
|
||||||
@ -105,6 +106,9 @@ type MockAccountManager struct {
|
|||||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||||
|
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||||
|
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||||
@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
|
||||||
ctx context.Context, userId, accountId, domain string,
|
if am.GetAccountIDByUserOrAccountIdFunc != nil {
|
||||||
) (*server.Account, error) {
|
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
|
||||||
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(
|
return "", status.Errorf(
|
||||||
codes.Unimplemented,
|
codes.Unimplemented,
|
||||||
"method GetAccountByUserOrAccountID is not implemented",
|
"method GetAccountIDByUserOrAccountID is not implemented",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
|
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||||
if am.SavePolicyFunc != nil {
|
if am.SavePolicyFunc != nil {
|
||||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||||
}
|
}
|
||||||
@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
error,
|
if am.GetAccountIDFromTokenFunc != nil {
|
||||||
) {
|
return am.GetAccountIDFromTokenFunc(ctx, claims)
|
||||||
if am.GetAccountFromTokenFunc != nil {
|
|
||||||
return am.GetAccountFromTokenFunc(ctx, claims)
|
|
||||||
}
|
}
|
||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
|
|||||||
}
|
}
|
||||||
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountByID mocks GetAccountByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
|
if am.GetAccountByIDFunc != nil {
|
||||||
|
return am.GetAccountByIDFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
|
||||||
|
if am.GetUserByIDFunc != nil {
|
||||||
|
return am.GetUserByIDFunc(ctx, id)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||||
|
if am.GetAccountSettingsFunc != nil {
|
||||||
|
return am.GetAccountSettingsFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
|
||||||
|
if am.GetAccountFunc != nil {
|
||||||
|
return am.GetAccountFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||||
|
}
|
||||||
|
@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
|||||||
|
|
||||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
|
|
||||||
}
|
|
||||||
|
|
||||||
nsGroup, found := account.NameServerGroups[nsGroupID]
|
|
||||||
if found {
|
|
||||||
return nsGroup.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
|
|
||||||
// ListNameServerGroups returns a list of nameserver groups from account
|
// ListNameServerGroups returns a list of nameserver groups from account
|
||||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
for _, item := range account.NameServerGroups {
|
|
||||||
nsGroups = append(nsGroups, item.Copy())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nsGroups, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||||
|
@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
policy.Enabled = false
|
policy.Enabled = false
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
|
|
||||||
// GetPolicy from the store
|
// GetPolicy from the store
|
||||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||||
if policy.ID == policyID {
|
|
||||||
return policy, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := am.savePolicy(account, policy)
|
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
action := activity.PolicyAdded
|
action := activity.PolicyAdded
|
||||||
if exists {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
action = activity.PolicyUpdated
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
|
|
||||||
// ListPolicies from the store
|
// ListPolicies from the store
|
||||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Policies, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||||
@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
return policy, nil
|
return policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
|
// savePolicy saves or updates a policy in the given account.
|
||||||
for i, p := range account.Policies {
|
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||||
if p.ID == policy.ID {
|
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||||
account.Policies[i] = policy
|
for index, rule := range policyToSave.Rules {
|
||||||
exists = true
|
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||||
break
|
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||||
|
policyToSave.Rules[index] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
if policyToSave.SourcePostureChecks != nil {
|
||||||
|
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUpdate {
|
||||||
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||||
|
if policyIdx < 0 {
|
||||||
|
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the existing policy
|
||||||
|
account.Policies[policyIdx] = policyToSave
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if !exists {
|
|
||||||
account.Policies = append(account.Policies, policy)
|
// Add the new policy to the account
|
||||||
}
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
return
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||||
@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||||
|
// that are valid within the provided account.
|
||||||
|
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||||
|
result := make([]string, 0, len(postureChecksIds))
|
||||||
|
for _, id := range postureChecksIds {
|
||||||
|
for _, postureCheck := range account.PostureChecks {
|
||||||
|
if id == postureCheck.ID {
|
||||||
|
result = append(result, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||||
|
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||||
|
result := make([]string, 0, len(groupIDs))
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if _, exists := account.Groups[groupID]; exists {
|
||||||
|
result = append(result, groupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
@ -15,30 +15,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, postureChecks := range account.PostureChecks {
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||||
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.PostureChecks, nil
|
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||||
|
@ -17,29 +17,16 @@ import (
|
|||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
wantedRoute, found := account.Routes[routeID]
|
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||||
if found {
|
|
||||||
return wantedRoute, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(peerID); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(domains) > 0 && prefix.IsValid() {
|
if len(domains) > 0 && prefix.IsValid() {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
// ListRoutes returns a list of routes from account
|
// ListRoutes returns a list of routes from account
|
||||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := make([]*route.Route, 0, len(account.Routes))
|
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||||
for _, item := range account.Routes {
|
|
||||||
routes = append(routes, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
|
@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
|||||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||||
|
|
||||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||||
|
@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
|
|
||||||
// ListSetupKeys returns a list of all setup keys of the account
|
// ListSetupKeys returns a list of all setup keys of the account
|
||||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
keys := make([]*SetupKey, 0, len(setupKeys))
|
||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
for _, key := range setupKeys {
|
||||||
}
|
|
||||||
|
|
||||||
keys := make([]*SetupKey, 0, len(account.SetupKeys))
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
var k *SetupKey
|
var k *SetupKey
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
if !user.IsAdminOrServiceUser() {
|
||||||
k = key.HiddenCopy(999)
|
k = key.HiddenCopy(999)
|
||||||
} else {
|
} else {
|
||||||
k = key.Copy()
|
k = key.Copy()
|
||||||
@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
|||||||
|
|
||||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
|
||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
|
||||||
}
|
|
||||||
|
|
||||||
var foundKey *SetupKey
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
if key.Id == keyID {
|
|
||||||
foundKey = key.Copy()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if foundKey == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||||
if foundKey.UpdatedAt.IsZero() {
|
if setupKey.UpdatedAt.IsZero() {
|
||||||
foundKey.UpdatedAt = foundKey.CreatedAt
|
setupKey.UpdatedAt = setupKey.CreatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
if !user.IsAdminOrServiceUser() {
|
||||||
foundKey = foundKey.HiddenCopy(999)
|
setupKey = setupKey.HiddenCopy(999)
|
||||||
}
|
}
|
||||||
|
|
||||||
return foundKey, nil
|
return setupKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||||
|
@ -36,6 +36,7 @@ const (
|
|||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
keyQueryCondition = "key = ?"
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
|
accountIDCondition = "account_id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||||
var account Account
|
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if err != nil {
|
||||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
return nil, err
|
||||||
strings.ToLower(domain), true, PrivateCategory)
|
|
||||||
if result.Error != nil {
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rework to not call GetAccount
|
// TODO: rework to not call GetAccount
|
||||||
return s.GetAccount(ctx, account.Id)
|
return s.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||||
|
var accountID string
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||||
|
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||||
|
strings.ToLower(domain), true, PrivateCategory,
|
||||||
|
).First(&accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||||
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||||
@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
|||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
First(&user, idQueryCondition, userID)
|
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||||
var user User
|
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
|||||||
func (s *SqlStore) GetDB() *gorm.DB {
|
func (s *SqlStore) GetDB() *gorm.DB {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||||
|
var accountDNSSettings AccountDNSSettings
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
return &accountDNSSettings.DNSSettings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
|
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||||
|
var accountID string
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Select("id").First(&accountID, idQueryCondition, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID != "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||||
|
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||||
|
var account Account
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||||
|
Where(idQueryCondition, accountID).First(&account)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Domain, account.DomainCategory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupByID retrieves a group by ID and account ID.
|
||||||
|
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||||
|
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupByName retrieves a group by name and account ID.
|
||||||
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||||
|
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||||
|
}
|
||||||
|
return &group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
|
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||||
|
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
|
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||||
|
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||||
|
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRouteByID retrieves a route by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||||
|
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||||
|
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||||
|
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||||
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRecords retrieves records from the database based on the account ID.
|
||||||
|
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||||
|
var record []T
|
||||||
|
|
||||||
|
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||||
|
recordType := parts[len(parts)-1]
|
||||||
|
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||||
|
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||||
|
var record T
|
||||||
|
|
||||||
|
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||||
|
recordType := parts[len(parts)-1]
|
||||||
|
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||||
|
}
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@ -39,53 +40,81 @@ const (
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
GetAllAccounts(ctx context.Context) []*Account
|
GetAllAccounts(ctx context.Context) []*Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||||
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountIDByUserID(peerKey string) (string, error)
|
GetAccountIDByUserID(userID string) (string, error)
|
||||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||||
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||||
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
|
||||||
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||||
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
|
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||||
|
|
||||||
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
|
|
||||||
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||||
|
|
||||||
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
|
|
||||||
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||||
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||||
|
|
||||||
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||||
|
|
||||||
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
|
||||||
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||||
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
SaveInstallationID(ctx context.Context, ID string) error
|
SaveInstallationID(ctx context.Context, ID string) error
|
||||||
|
|
||||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||||
AcquireGlobalLock(ctx context.Context) func()
|
AcquireGlobalLock(ctx context.Context) func()
|
||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
|
||||||
// Close should close the store persisting all unsaved data.
|
// Close should close the store persisting all unsaved data.
|
||||||
Close(ctx context.Context) error
|
Close(ctx context.Context) error
|
||||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||||
// This is also a method of metrics.DataSource interface.
|
// This is also a method of metrics.DataSource interface.
|
||||||
GetStoreEngine() StoreEngine
|
GetStoreEngine() StoreEngine
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
|
||||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
|
||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
|
||||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
|
||||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
|
||||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
|
||||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
|
||||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool {
|
|||||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
|
||||||
|
func (u *User) IsAdminOrServiceUser() bool {
|
||||||
|
return u.HasAdminPower() || u.IsServiceUser
|
||||||
|
}
|
||||||
|
|
||||||
// ToUserInfo converts a User object to a UserInfo object.
|
// ToUserInfo converts a User object to a UserInfo object.
|
||||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||||
autoGroups := u.AutoGroups
|
autoGroups := u.AutoGroups
|
||||||
@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
|
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
|
||||||
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided authorization claims.
|
// GetUser looks up a user by provided authorization claims.
|
||||||
// It will also create an account if didn't exist for this user before.
|
// It will also create an account if didn't exist for this user before.
|
||||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get an account from store %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := account.Users[claims.UserId]
|
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
|
|
||||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newLogin {
|
if newLogin {
|
||||||
meta := map[string]any{"timestamp": claims.LastLogin}
|
meta := map[string]any{"timestamp": claims.LastLogin}
|
||||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
|||||||
|
|
||||||
// GetPAT returns a specific PAT from a user
|
// GetPAT returns a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||||
if !ok {
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
for _, pat := range targetUser.PATsG {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
if pat.ID == tokenID {
|
||||||
|
return pat.Copy(), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pat := targetUser.PATs[tokenID]
|
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||||
if pat == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return pat, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
var pats []*PersonalAccessToken
|
pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
|
||||||
for _, pat := range targetUser.PATs {
|
for _, pat := range targetUser.PATsG {
|
||||||
pats = append(pats, pat)
|
pats = append(pats, pat.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
return pats, nil
|
return pats, nil
|
||||||
|
@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) {
|
|||||||
defer store.Close(context.Background())
|
defer store.Close(context.Background())
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
account.Users[mockUserID] = &User{
|
account.Users[mockUserID] = &User{
|
||||||
Id: mockUserID,
|
Id: mockUserID,
|
||||||
|
AccountID: mockAccountID,
|
||||||
PATs: map[string]*PersonalAccessToken{
|
PATs: map[string]*PersonalAccessToken{
|
||||||
mockTokenID1: {
|
mockTokenID1: {
|
||||||
ID: mockTokenID1,
|
ID: mockTokenID1,
|
||||||
@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) {
|
|||||||
defer store.Close(context.Background())
|
defer store.Close(context.Background())
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
account.Users[mockUserID] = &User{
|
account.Users[mockUserID] = &User{
|
||||||
Id: mockUserID,
|
Id: mockUserID,
|
||||||
|
AccountID: mockAccountID,
|
||||||
PATs: map[string]*PersonalAccessToken{
|
PATs: map[string]*PersonalAccessToken{
|
||||||
mockTokenID1: {
|
mockTokenID1: {
|
||||||
ID: mockTokenID1,
|
ID: mockTokenID1,
|
||||||
@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
acc, err := am.Store.GetAccount(context.Background(), accID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, id := range tc.expectedDeleted {
|
for _, id := range tc.expectedDeleted {
|
||||||
|
Loading…
Reference in New Issue
Block a user