mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-20 11:33:07 +02:00
Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage
# Conflicts: # management/server/file_store.go
This commit is contained in:
@@ -1266,7 +1266,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
// Returns the account ID or an error if none is found or created.
|
// Returns the account ID or an error if none is found or created.
|
||||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||||
if accountID != "" {
|
if accountID != "" {
|
||||||
exists, err := am.Store.AccountExists(ctx, accountID)
|
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@@ -647,7 +647,7 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -985,7 +985,7 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AccountExists checks whether an account exists by the given ID.
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
|
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
|
||||||
_, exists := s.Accounts[id]
|
_, exists := s.Accounts[id]
|
||||||
return exists, nil
|
return exists, nil
|
||||||
}
|
}
|
||||||
@@ -1002,7 +1002,7 @@ func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ st
|
|||||||
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
|
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1019,7 +1019,8 @@ func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string)
|
|||||||
return status.Errorf(status.Internal, "DeletePolicy is not implemented")
|
return status.Errorf(status.Internal, "DeletePolicy is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
|
||||||
|
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1035,7 +1036,7 @@ func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _
|
|||||||
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
|
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
|
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1043,7 +1044,7 @@ func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string,
|
|||||||
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) {
|
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1051,7 +1052,7 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri
|
|||||||
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) {
|
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountNameServerGroups(ctx, accountID)
|
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||||
|
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPolicies(ctx, accountID)
|
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||||
|
@@ -137,7 +137,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
|||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||||
|
@@ -321,7 +321,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountRoutes(ctx, accountID)
|
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
|
@@ -339,7 +339,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
|||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID)
|
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -1056,10 +1056,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AccountExists checks whether an account exists by the given ID.
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
|
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||||
var count int64
|
var count int64
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Where(idQueryCondition, id).Count(&count)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return false, result.Error
|
return false, result.Error
|
||||||
}
|
}
|
||||||
@@ -1103,8 +1104,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountPolicies retrieves policies for an account.
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
|
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||||
@@ -1125,8 +1126,8 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
|
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
@@ -1156,8 +1157,8 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||||
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
|
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRouteByID retrieves a route by its ID and account ID.
|
// GetRouteByID retrieves a route by its ID and account ID.
|
||||||
@@ -1166,8 +1167,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||||
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
|
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||||
@@ -1176,8 +1177,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
|
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||||
@@ -1186,10 +1187,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getRecords retrieves records from the database based on the account ID.
|
// getRecords retrieves records from the database based on the account ID.
|
||||||
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
|
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||||
var record []T
|
var record []T
|
||||||
|
|
||||||
result := db.Find(&record, accountIDCondition, accountID)
|
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||||
recordType := parts[len(parts)-1]
|
recordType := parts[len(parts)-1]
|
||||||
|
@@ -40,7 +40,7 @@ const (
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
GetAllAccounts(ctx context.Context) []*Account
|
GetAllAccounts(ctx context.Context) []*Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
AccountExists(ctx context.Context, id string) (bool, error)
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||||
@@ -69,13 +69,13 @@ type Store interface {
|
|||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
|
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
|
||||||
@@ -91,13 +91,13 @@ type Store interface {
|
|||||||
|
|
||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error)
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||||
|
|
||||||
GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error)
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||||
|
|
||||||
GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error)
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
|
Reference in New Issue
Block a user