mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 10:21:10 +01:00
Refactor store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
16174f0478
commit
41b212f610
@ -10,11 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -979,11 +981,15 @@ func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
|
||||
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||
}
|
||||
|
||||
@ -1007,3 +1013,27 @@ func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*pos
|
||||
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
||||
}
|
||||
|
@ -1035,6 +1035,20 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
|
||||
// UpdateAccount updates an existing account's domain, DNS settings, and settings fields.
|
||||
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
|
||||
updates := make(map[string]interface{})
|
||||
@ -1071,6 +1085,7 @@ func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStreng
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
@ -1081,6 +1096,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
@ -1093,24 +1109,17 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
|
||||
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||
}
|
||||
return &group, nil
|
||||
// GetGroupByID
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
|
||||
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
@ -1120,47 +1129,85 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
|
||||
var policies []*Policy
|
||||
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
|
||||
Preload(clause.Associations).Find(&policies)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||
}
|
||||
return policies, nil
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
var policy *Policy
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
|
||||
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||
}
|
||||
return policy, nil
|
||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||
}
|
||||
return postureChecks, nil
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
var postureCheck *posture.Checks
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
|
||||
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||
}
|
||||
return postureCheck, nil
|
||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
|
||||
result := db.Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||
var record T
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@ -51,6 +52,7 @@ type Store interface {
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
@ -64,7 +66,7 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
|
||||
@ -86,6 +88,14 @@ type Store interface {
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
|
||||
GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
|
Loading…
Reference in New Issue
Block a user