Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage

# Conflicts:
#	management/server/file_store.go
This commit is contained in:
bcmmbaga
2024-09-26 18:51:47 +03:00
9 changed files with 35 additions and 33 deletions

View File

@@ -1266,7 +1266,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// Returns the account ID or an error if none is found or created. // Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" { if accountID != "" {
exists, err := am.Store.AccountExists(ctx, accountID) exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -647,7 +647,7 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
return user, nil return user, nil
} }
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -985,7 +985,7 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStre
} }
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) { func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
_, exists := s.Accounts[id] _, exists := s.Accounts[id]
return exists, nil return exists, nil
} }
@@ -1002,7 +1002,7 @@ func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ st
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
} }
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) { func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
} }
@@ -1019,7 +1019,8 @@ func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string)
return status.Errorf(status.Internal, "DeletePolicy is not implemented") return status.Errorf(status.Internal, "DeletePolicy is not implemented")
} }
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
} }
@@ -1035,7 +1036,7 @@ func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented") return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
} }
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) { func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
} }
@@ -1043,7 +1044,7 @@ func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string,
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
} }
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) { func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
} }
@@ -1051,7 +1052,7 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
} }
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) { func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
} }

View File

@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
} }
return am.Store.GetAccountNameServerGroups(ctx, accountID) return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
} }
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {

View File

@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies") return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
} }
return am.Store.GetAccountPolicies(ctx, accountID) return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {

View File

@@ -137,7 +137,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
} }
return am.Store.GetAccountPostureChecks(ctx, accountID) return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
} }
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.

View File

@@ -321,7 +321,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
return am.Store.GetAccountRoutes(ctx, accountID) return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {

View File

@@ -339,7 +339,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
} }
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID) setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1056,10 +1056,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
} }
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) { func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var count int64 var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, id).Count(&count)
if result.Error != nil { if result.Error != nil {
return false, result.Error return false, result.Error
} }
@@ -1103,8 +1104,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
} }
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
} }
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
@@ -1125,8 +1126,8 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
} }
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID) return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
@@ -1156,8 +1157,8 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
} }
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), accountID) return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
@@ -1166,8 +1167,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID) return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
@@ -1176,8 +1177,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
} }
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID) return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
@@ -1186,10 +1187,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
} }
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) { func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T var record []T
result := db.Find(&record, accountIDCondition, accountID) result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".") parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1] recordType := parts[len(parts)-1]

View File

@@ -40,7 +40,7 @@ const (
type Store interface { type Store interface {
GetAllAccounts(ctx context.Context) []*Account GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
AccountExists(ctx context.Context, id string) (bool, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
@@ -69,13 +69,13 @@ type Store interface {
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
@@ -91,13 +91,13 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)