[management] Remove redundant get account calls in GetAccountFromToken (#2615)

* refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor jwt groups extractor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to get account when necessary

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* revert handles change

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove GetUserByID from account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims to return account id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByName from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByID from store and refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user permissions and retrieves PAT

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor route, setupkey, nameserver and dns to get record(s) from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix add missing policy source posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add store lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga
2024-09-27 17:10:50 +03:00
committed by GitHub
parent 4ebf6e1c4c
commit acb73bd64a
44 changed files with 1279 additions and 981 deletions

View File

@ -36,6 +36,7 @@ const (
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
}
// TODO: rework to not call GetAccount
return s.GetAccount(ctx, account.Id)
return s.GetAccount(ctx, accountID)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID)
result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil
}
return false, result.Error
}
return accountID != "", nil
}
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
return account.Domain, account.DomainCategory, nil
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
}
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
}
return &group, nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}