mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 18:01:13 +01:00
Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage
# Conflicts: # management/server/file_store.go
This commit is contained in:
commit
f61c914fd7
@ -1266,7 +1266,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||
if accountID != "" {
|
||||
exists, err := am.Store.AccountExists(ctx, accountID)
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -647,7 +647,7 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -985,7 +985,7 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStre
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
|
||||
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
|
||||
_, exists := s.Accounts[id]
|
||||
return exists, nil
|
||||
}
|
||||
@ -1002,7 +1002,7 @@ func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ st
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
|
||||
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||
}
|
||||
|
||||
@ -1019,7 +1019,8 @@ func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string)
|
||||
return status.Errorf(status.Internal, "DeletePolicy is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
||||
|
||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||
}
|
||||
|
||||
@ -1035,7 +1036,7 @@ func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _
|
||||
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
|
||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||
}
|
||||
|
||||
@ -1043,7 +1044,7 @@ func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string,
|
||||
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) {
|
||||
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||
}
|
||||
|
||||
@ -1051,7 +1052,7 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri
|
||||
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) {
|
||||
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||
}
|
||||
|
||||
|
@ -154,7 +154,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, accountID)
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||
|
@ -443,7 +443,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, accountID)
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
|
@ -137,7 +137,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
|
@ -321,7 +321,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, accountID)
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
|
@ -339,7 +339,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID)
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1056,10 +1056,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var count int64
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, id).Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
@ -1103,8 +1104,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
@ -1125,8 +1126,8 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
@ -1156,8 +1157,8 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
@ -1166,8 +1167,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
@ -1176,8 +1177,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
@ -1186,10 +1187,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
|
||||
result := db.Find(&record, accountIDCondition, accountID)
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
@ -40,7 +40,7 @@ const (
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
AccountExists(ctx context.Context, id string) (bool, error)
|
||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
@ -69,13 +69,13 @@ type Store interface {
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
|
||||
@ -91,13 +91,13 @@ type Store interface {
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error)
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
|
||||
GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error)
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
|
Loading…
Reference in New Issue
Block a user