diff --git a/management/server/file_store.go b/management/server/file_store.go index 7b766a2e3..f6057b281 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,11 +10,13 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -979,11 +981,15 @@ func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) { return exists, nil } +func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { + return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") +} + func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error { return nil } -func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) { +func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") } @@ -1007,3 +1013,27 @@ func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*pos func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") } + +func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") +} + +func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") +} + +func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") +} + +func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") +} + +func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") +} + +func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 5094c589b..cc3d771d9 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1035,6 +1035,20 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { + var accountDNSSettings AccountDNSSettings + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountDNSSettings, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "dns settings not found") + } + return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + } + return &accountDNSSettings.DNSSettings, nil +} + // UpdateAccount updates an existing account's domain, DNS settings, and settings fields. func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error { updates := make(map[string]interface{}) @@ -1071,6 +1085,7 @@ func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStreng // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) { var count int64 + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count) if result.Error != nil { return false, result.Error @@ -1081,6 +1096,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) { // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { var account Account + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { @@ -1093,24 +1109,17 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength return account.Domain, account.DomainCategory, nil } -func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) { - var group nbgroup.Group - result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations). - Where(accountAndIDQueryCondition, accountID, groupID).First(&group) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") - } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) - } - return &group, nil +// GetGroupByID +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { + return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) } // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}). - Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group) + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") @@ -1120,47 +1129,85 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) { - var policies []*Policy - result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID). - Preload(clause.Associations).Find(&policies) - if result.Error != nil { - return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error) - } - return policies, nil + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID) } +// GetPolicyByID retrieves a policy by its ID and account ID. func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - var policy *Policy - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}). - Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy) - if err := result.Error; err != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "posture checks not found") - } - return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error) - } - return policy, nil + return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) } +// GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { - var postureChecks []*posture.Checks - result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks) - if result.Error != nil { - return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error) - } - return postureChecks, nil + return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID) } +// GetPostureChecksByID retrieves posture checks by their ID and account ID. func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - var postureCheck *posture.Checks - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}). - Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck) - if err := result.Error; err != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "posture checks not found") - } - return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error) - } - return postureCheck, nil + return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) +} + +// GetAccountRoutes retrieves network routes for an account. +func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) { + return getRecords[*route.Route](s.db.WithContext(ctx), accountID) +} + +// GetRouteByID retrieves a route by its ID and account ID. +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { + return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) +} + +// GetAccountSetupKeys retrieves setup keys for an account. +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) { + return getRecords[*SetupKey](s.db.WithContext(ctx), accountID) +} + +// GetSetupKeyByID retrieves a setup key by its ID and account ID. +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { + return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +} + +// GetAccountNameServerGroups retrieves name server groups for an account. +func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) { + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID) +} + +// GetNameServerGroupByID retrieves a name server group by its ID and account ID. +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { + return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) +} + +// getRecords retrieves records from the database based on the account ID. +func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) { + var record []T + + result := db.Find(&record, accountIDCondition, accountID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) + } + + return record, nil +} + +// getRecordByID retrieves a record by its ID and account ID from the database. +func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { + var record T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "%s not found", recordType) + } + return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) + } + return &record, nil } diff --git a/management/server/store.go b/management/server/store.go index 601e173e2..c629a4c75 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -51,6 +52,7 @@ type Store interface { GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error @@ -64,7 +66,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error @@ -86,6 +88,14 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + + GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + + GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error) + GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, accountId string) error