netbird/management/server/sql_store.go

1216 lines
41 KiB
Go
Raw Normal View History

package server
import (
"context"
2024-05-30 15:22:42 +02:00
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2023-11-28 13:45:26 +01:00
nbdns "github.com/netbirdio/netbird/dns"
2023-11-30 11:51:35 +01:00
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
2023-11-28 13:45:26 +01:00
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
2023-11-28 13:45:26 +01:00
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
resourceLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine StoreEngine
}
type installation struct {
ID uint `gorm:"primaryKey"`
InstallationIDValue string
}
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
}
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
}
err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
Extend system meta (#1598) * wip: add posture checks structs * add netbird version check * Refactor posture checks and add version checks * Add posture check activities (#1445) * Integrate Endpoints for Posture Checks (#1432) * wip: add posture checks structs * add netbird version check * Refactor posture checks and add version checks * Implement posture and version checks in API models * Refactor API models and enhance posture check functionality * wip: add posture checks endpoints * go mod tidy * Reference the posture checks by id's in policy * Add posture checks management to server * Add posture checks management mocks * implement posture checks handlers * Add posture checks to account copy and fix tests * Refactor posture checks validation * wip: Add posture checks handler tests * Add JSON encoding support to posture checks * Encode posture checks to correct api response object * Refactored posture checks implementation to align with the new API schema * Refactor structure of `Checks` from slice to map * Cleanup * Add posture check activities (#1445) * Revert map to use list of checks * Add posture check activity events * Refactor posture check initialization in account test * Improve the handling of version range in posture check * Fix tests and linter * Remove max_version from NBVersionCheck * Added unit tests for NBVersionCheck * go mod tidy * Extend policy endpoint with posture checks (#1450) * Implement posture and version checks in API models * go mod tidy * Allow attaching posture checks to policy * Update error message for linked posture check on deleting * Refactor PostureCheck and Checks structures * go mod tidy * Add validation for non-existing posture checks * fix unit tests * use Wt version * Remove the enabled field, as posture check will now automatically be activated by default when attaching to a policy * wip: add posture checks structs * add netbird version check * Refactor posture checks and add version checks * Add posture check activities (#1445) * Integrate Endpoints for Posture Checks (#1432) * wip: add posture checks structs * add netbird version check * Refactor posture checks and add version checks * Implement posture and version checks in API models * Refactor API models and enhance posture check functionality * wip: add posture checks endpoints * go mod tidy * Reference the posture checks by id's in policy * Add posture checks management to server * Add posture checks management mocks * implement posture checks handlers * Add posture checks to account copy and fix tests * Refactor posture checks validation * wip: Add posture checks handler tests * Add JSON encoding support to posture checks * Encode posture checks to correct api response object * Refactored posture checks implementation to align with the new API schema * Refactor structure of `Checks` from slice to map * Cleanup * Add posture check activities (#1445) * Revert map to use list of checks * Add posture check activity events * Refactor posture check initialization in account test * Improve the handling of version range in posture check * Fix tests and linter * Remove max_version from NBVersionCheck * Added unit tests for NBVersionCheck * go mod tidy * Extend policy endpoint with posture checks (#1450) * Implement posture and version checks in API models * go mod tidy * Allow attaching posture checks to policy * Update error message for linked posture check on deleting * Refactor PostureCheck and Checks structures * go mod tidy * Add validation for non-existing posture checks * fix unit tests * use Wt version * Remove the enabled field, as posture check will now automatically be activated by default when attaching to a policy * Extend network map generation with posture checks (#1466) * Apply posture checks to network map generation * run policy posture checks on peers to connect * Refactor and streamline policy posture check process for peers to connect. * Add posture checks testing in a network map * Remove redundant nil check in policy.go * Refactor peer validation check in policy.go * Update 'Check' function signature and use logger for version check * Refactor posture checks run on sources and updated the validation func * Update peer validation * fix tests * improved test coverage for policy posture check * Refactoring * Extend NetBird agent to collect kernel version (#1495) * Add KernelVersion field to LoginRequest * Add KernelVersion to system info retrieval * Fix tests * Remove Core field from system info * Replace Core field with new OSVersion field in system info * Added WMI dependency to info_windows.go * Add OS Version posture checks (#1479) * Initial support of Geolocation service (#1491) * Add Geo Location posture check (#1500) * wip: implement geolocation check * add geo location posture checks to posture api * Merge branch 'feature/posture-checks' into geo-posture-check * Remove CityGeoNameID and update required fields in API * Add geoLocation checks to posture checks handler tests * Implement geo location-based checks for peers * Update test values and embed location struct in peer system * add support for country wide checks * initialize country code regex once * Fix peer meta core compability with older clients (#1515) * Refactor extraction of OSVersion in grpcserver * Ignore lint check * Fix peer meta core compability with older management (#1532) * Revert core field deprecation * fix tests * Extend peer meta with location information (#1517) This PR uses the geolocation service to resolve IP to location. The lookup happens once on the first connection - when a client calls the Sync func. The location is stored as part of the peer: * Add Locations endpoints (#1516) * add locations endpoints * Add sqlite3 check and database generation in geolite script * Add SQLite storage for geolocation data * Refactor file existence check into a separate function * Integrate geolocation services into management application * Refactoring * Refactor city retrieval to include Geonames ID * Add signature verification for GeoLite2 database download * Change to in-memory database for geolocation store * Merge manager to geolocation * Update GetAllCountries to return Country name and iso code * fix tests * Add reload to SqliteStore * Add geoname indexes * move db file check to connectDB * Add concurrency safety to SQL queries and database reloading The commit adds mutex locks to the GetAllCountries and GetCitiesByCountry functions to ensure thread-safety during database queries. Additionally, it introduces a mechanism to safely close the old database connection before a new connection is established upon reloading, which improves the reliability of database operations. Lastly, it moves the checking of database file existence to the connectDB function. * Add sha256 sum check to geolocation store before reload * Use read lock * Check SHA256 twice when reload geonames db --------- Co-authored-by: Yury Gargay <yury.gargay@gmail.com> * Add tests and validation for empty peer location in GeoLocationCheck (#1546) * Disallow Geo check creation/update without configured Geo DB (#1548) * Fix shared access to in memory copy of geonames.db (#1550) * Trim suffix in when evaluate Min Kernel Version in OS check * Add Valid Peer Windows Kernel version test * Add Geolocation handler tests (#1556) * Implement user admin checks in posture checks * Add geolocation handler tests * Mark initGeolocationTestData as helper func * Add error handling to geolocation database closure * Add cleanup function to close geolocation resources * Simplify checks definition serialisation (#1555) * Regenerate network map on posture check update (#1563) * change network state and generate map on posture check update * Refactoring * Make city name optional (#1575) * Do not return empty city name * Validate action param of geo location checks (#1577) We only support allow and deny * Switch realip middleware to upstream (#1578) * Be more silent in download-geolite2.sh script * Fix geonames db reload (#1580) * Ensure posture check name uniqueness when create (#1594) * Enhance the management of posture checks (#1595) * add a correct min version and kernel for os posture check example * handle error when geo or location db is nil * expose all peer location details in api response * Check for nil geolocation manager only * Validate posture check before save * bump open api version * add peer location fields to toPeerListItemResponse * Feautre/extend sys meta (#1536) * Collect network addresses * Add Linux sys product info * Fix peer meta comparison * Collect sys info on mac * Add windows sys info * Fix test * Fix test * Fix grpc client * Ignore test * Fix test * Collect IPv6 addresses * Change the IP to IP + net * fix tests * Use netip on server side * Serialize netip to json * Extend Peer metadata with cloud detection (#1552) * add cloud detection + test binary * test windows exe * Collect IPv6 addresses * Change the IP to IP + net * switch to forked cloud detect lib * new test builds * new GCE build * discontinue using library but local copy instead * fix imports * remove openstack check * add hierarchy to cloud check * merge IBM and SoftLayer * close resp bodies and use os lib for file reading * close more resp bodies * fix error check logic * parallelize IBM checks * fix response value * go mod tidy * include context + change kubernetes detection * add context in info functions * extract platform into separate field * fix imports * add missing wmi import --------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> --------- Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> * generate proto * remove test binaries --------- Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Yury Gargay <yury.gargay@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-02-20 11:53:11 +01:00
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
}
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.WithContext(ctx).Tracef("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
return unlock
}
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
2024-05-07 14:30:03 +02:00
mtx := value.(*sync.RWMutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
2024-05-07 14:30:03 +02:00
}
return unlock
}
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
2024-05-07 14:30:03 +02:00
start := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
2024-05-07 14:30:03 +02:00
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("SaveAccount for account %s exceeded 1s, took: %v", account.Id, elapsed)
}
}()
// todo: remove this check after the issue is resolved
s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain)
generateAccountSQLTypes(account)
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
return err
}
// generateAccountSQLTypes generates the GORM compatible types for the account
func generateAccountSQLTypes(account *Account) {
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
for id, peer := range account.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
}
for id, user := range account.Users {
user.Id = id
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
}
for id, route := range account.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
}
}
// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank
func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) {
var acc Account
var domain string
result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain)
if result.Error != nil {
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error)
}
return
}
if domain != "" && newDomain == "" {
log.WithContext(ctx).Warnf("saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s", domain, accountID, debug.Stack())
}
}
func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error {
start := time.Now()
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error {
installation := installation{InstallationIDValue: ID}
installation.ID = uint(s.installationPK)
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
}
func (s *SqlStore) GetInstallationID() string {
var installation installation
if result := s.db.First(&installation, idQueryCondition, s.installationPK); result.Error != nil {
return ""
}
return installation.InstallationIDValue
}
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
if result.Error != nil {
return result.Error
}
if peerID == "" {
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
}
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
}
return nil
})
if err != nil {
return err
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
2024-05-07 14:30:03 +02:00
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
2024-05-07 14:30:03 +02:00
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
2024-05-07 14:30:03 +02:00
return result.Error
}
2024-05-07 14:30:03 +02:00
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
2024-05-07 14:30:03 +02:00
}
2024-05-07 14:30:03 +02:00
return nil
}
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
2024-05-07 14:30:03 +02:00
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
2024-05-07 14:30:03 +02:00
Updates(peerCopy)
if result.Error != nil {
2024-05-07 14:30:03 +02:00
return result.Error
}
2024-05-07 14:30:03 +02:00
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
2024-05-07 14:30:03 +02:00
}
2024-05-07 14:30:03 +02:00
return nil
}
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
}
// DeleteTokenID2UserIDIndex is noop in SqlStore
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
return nil
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
}
// TODO: rework to not call GetAccount
return s.GetAccount(ctx, accountID)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(ctx, key.AccountID)
}
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
2024-08-19 12:50:11 +02:00
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
2024-08-19 12:50:11 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
2024-08-19 12:50:11 +02:00
}
return nil, status.NewGetUserFromStoreError()
2024-08-19 12:50:11 +02:00
}
return &user, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
2024-08-19 12:50:11 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}
return groups, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
if result.Error != nil {
return all
}
for _, account := range accounts {
if acc, err := s.GetAccount(ctx, account.Id); err == nil {
all = append(all, acc)
}
}
return all
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
2024-05-07 14:30:03 +02:00
var account Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*PolicyRule
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
2023-11-28 13:45:26 +01:00
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
2024-05-06 14:47:49 +02:00
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
return &account, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
if user.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(ctx, user.AccountID)
}
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
2023-11-28 13:45:26 +01:00
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(ctx, peer.AccountID)
}
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
2023-11-28 13:45:26 +01:00
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(ctx, peer.AccountID)
}
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
2024-05-07 14:30:03 +02:00
var peer nbpeer.Peer
var accountID string
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
2024-05-07 14:30:03 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
2024-05-31 16:41:12 +02:00
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
2024-05-31 16:41:12 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
2024-05-31 16:41:12 +02:00
var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
2024-05-31 16:41:12 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
2024-05-31 16:41:12 +02:00
}
return accountID, nil
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
2024-05-31 16:41:12 +02:00
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
2024-05-31 16:41:12 +02:00
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
2024-05-31 16:41:12 +02:00
var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
2024-05-31 16:41:12 +02:00
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
}
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
return s.db.Save(&user).Error
}
2024-05-30 15:22:42 +02:00
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
definitionJSON, err := json.Marshal(checks)
if err != nil {
return nil, err
}
var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
if err != nil {
return nil, err
}
return &postureCheck, nil
}
// Close closes the underlying DB connection
func (s *SqlStore) Close(_ context.Context) error {
sql, err := s.db.DB()
if err != nil {
return fmt.Errorf("get db: %w", err)
}
return sql.Close()
}
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() StoreEngine {
return s.storeEngine
}
// NewSqliteStore creates a new SQLite store.
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = storeSqliteFileName
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, err
}
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics)
}
// NewPostgresqlStore creates a new Postgres store.
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
TranslateError: true,
}
}
// newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
return NewPostgresqlStore(ctx, dsn, metrics)
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewSqliteStore(ctx, dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(ctx, fileStore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range fileStore.GetAllAccounts(ctx) {
err := store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
}
return store, nil
}
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewPostgresqlStore(ctx, dsn, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(ctx, fileStore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range fileStore.GetAllAccounts(ctx) {
err := store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
}
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
return count > 0, nil
}
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
return account.Domain, account.DomainCategory, nil
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
}
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
}
return &group, nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&policy).Error
}
// DeletePolicy deletes a policy from the database.
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, idQueryCondition, policyID).Error
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&postureCheck)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
return status.Errorf(status.InvalidArgument, "name should be unique")
}
return result.Error
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, idQueryCondition, postureChecksID).Error
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
var record []T
result := db.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}