Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage

# Conflicts:
#	management/server/file_store.go
This commit is contained in:
bcmmbaga 2024-09-26 18:51:47 +03:00
commit f61c914fd7
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
9 changed files with 35 additions and 33 deletions

View File

@ -1266,7 +1266,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
exists, err := am.Store.AccountExists(ctx, accountID)
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", err
}

View File

@ -647,7 +647,7 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
return user, nil
}
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
@ -985,7 +985,7 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStre
}
// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}
@ -1002,7 +1002,7 @@ func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ st
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
@ -1019,7 +1019,8 @@ func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string)
return status.Errorf(status.Internal, "DeletePolicy is not implemented")
}
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}
@ -1035,7 +1036,7 @@ func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
}
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
}
@ -1043,7 +1044,7 @@ func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string,
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
}
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) {
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
}
@ -1051,7 +1052,7 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
}
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) {
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
}

View File

@ -154,7 +154,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
return am.Store.GetAccountNameServerGroups(ctx, accountID)
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {

View File

@ -443,7 +443,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
}
return am.Store.GetAccountPolicies(ctx, accountID)
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {

View File

@ -137,7 +137,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return am.Store.GetAccountPostureChecks(ctx, accountID)
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.

View File

@ -321,7 +321,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
return am.Store.GetAccountRoutes(ctx, accountID)
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
}
func toProtocolRoute(route *route.Route) *proto.Route {

View File

@ -339,7 +339,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID)
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}

View File

@ -1056,10 +1056,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
@ -1103,8 +1104,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
@ -1125,8 +1126,8 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
@ -1156,8 +1157,8 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
@ -1166,8 +1167,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
@ -1176,8 +1177,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
@ -1186,10 +1187,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Find(&record, accountIDCondition, accountID)
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]

View File

@ -40,7 +40,7 @@ const (
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
AccountExists(ctx context.Context, id string) (bool, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
@ -69,13 +69,13 @@ type Store interface {
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
@ -91,13 +91,13 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error)
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error)
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error)
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)