2023-10-12 15:42:36 +02:00
package server
import (
2024-07-03 11:33:02 +02:00
"context"
2024-05-30 15:22:42 +02:00
"encoding/json"
2024-03-10 19:09:45 +01:00
"errors"
2024-02-22 12:27:08 +01:00
"fmt"
2024-09-16 15:47:03 +02:00
"net"
2024-06-13 12:39:19 +02:00
"os"
2023-10-12 15:42:36 +02:00
"path/filepath"
"runtime"
2024-07-02 12:40:26 +02:00
"runtime/debug"
2023-10-12 15:42:36 +02:00
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
2024-05-16 18:28:37 +02:00
"gorm.io/driver/postgres"
2023-10-12 15:42:36 +02:00
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2023-11-28 13:45:26 +01:00
nbdns "github.com/netbirdio/netbird/dns"
2023-11-30 11:51:35 +01:00
"github.com/netbirdio/netbird/management/server/account"
2024-03-27 18:48:48 +01:00
nbgroup "github.com/netbirdio/netbird/management/server/group"
2023-11-28 13:45:26 +01:00
nbpeer "github.com/netbirdio/netbird/management/server/peer"
2024-02-20 09:59:56 +01:00
"github.com/netbirdio/netbird/management/server/posture"
2023-11-28 13:45:26 +01:00
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
2023-10-12 15:42:36 +02:00
)
2024-06-13 12:39:19 +02:00
// Shared SQL fragments and message formats used by the SqlStore queries.
const (
	// storeSqliteFileName is the SQLite store file name (its use lies outside this section).
	storeSqliteFileName = "store.db"

	// WHERE-clause templates passed to gorm First/Find/Where calls.
	idQueryCondition           = "id = ?"
	keyQueryCondition          = "key = ?"
	accountIDCondition         = "account_id = ?"
	accountAndIDQueryCondition = "account_id = ? and id = ?"

	// peerNotFoundFMT is the format of the NotFound error returned for a missing peer.
	peerNotFoundFMT = "peer %s not found"
)
2024-05-16 18:28:37 +02:00
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
2023-10-12 15:42:36 +02:00
db * gorm . DB
2024-07-31 14:53:32 +02:00
resourceLocks sync . Map
2023-10-12 15:42:36 +02:00
globalAccountLock sync . Mutex
metrics telemetry . AppMetrics
installationPK int
2024-05-16 18:28:37 +02:00
storeEngine StoreEngine
2023-10-12 15:42:36 +02:00
}
// installation is the gorm model of the row that holds the installation ID.
type installation struct {
	// ID is the primary key; SqlStore reads and writes it as installationPK.
	ID uint `gorm:"primaryKey"`
	// InstallationIDValue is the stored installation identifier.
	InstallationIDValue string
}
2024-04-18 18:14:21 +02:00
// migrationFunc is the signature of a migration step that operates on a gorm
// handle; the steps themselves are defined outside this section.
type migrationFunc func(db *gorm.DB) error
2024-05-16 18:28:37 +02:00
// NewSqlStore creates a new SqlStore instance.
2024-07-03 11:33:02 +02:00
func NewSqlStore ( ctx context . Context , db * gorm . DB , storeEngine StoreEngine , metrics telemetry . AppMetrics ) ( * SqlStore , error ) {
2023-10-12 15:42:36 +02:00
sql , err := db . DB ( )
if err != nil {
return nil , err
}
conns := runtime . NumCPU ( )
sql . SetMaxOpenConns ( conns ) // TODO: make it configurable
2024-07-03 11:33:02 +02:00
if err := migrate ( ctx , db ) ; err != nil {
2024-04-18 18:14:21 +02:00
return nil , fmt . Errorf ( "migrate: %w" , err )
}
2023-10-12 15:42:36 +02:00
err = db . AutoMigrate (
2024-03-27 18:48:48 +01:00
& SetupKey { } , & nbpeer . Peer { } , & User { } , & PersonalAccessToken { } , & nbgroup . Group { } ,
2023-10-12 15:42:36 +02:00
& Account { } , & Policy { } , & PolicyRule { } , & route . Route { } , & nbdns . NameServerGroup { } ,
2024-02-20 11:53:11 +01:00
& installation { } , & account . ExtraSettings { } , & posture . Checks { } , & nbpeer . NetworkAddress { } ,
2023-10-12 15:42:36 +02:00
)
if err != nil {
2024-04-18 18:14:21 +02:00
return nil , fmt . Errorf ( "auto migrate: %w" , err )
2023-10-12 15:42:36 +02:00
}
2024-05-16 18:28:37 +02:00
return & SqlStore { db : db , storeEngine : storeEngine , metrics : metrics , installationPK : 1 } , nil
2023-10-12 15:42:36 +02:00
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) AcquireGlobalLock ( ctx context . Context ) ( unlock func ( ) ) {
log . WithContext ( ctx ) . Tracef ( "acquiring global lock" )
2023-10-12 15:42:36 +02:00
start := time . Now ( )
s . globalAccountLock . Lock ( )
unlock = func ( ) {
s . globalAccountLock . Unlock ( )
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Tracef ( "released global lock in %v" , time . Since ( start ) )
2023-10-12 15:42:36 +02:00
}
took := time . Since ( start )
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Tracef ( "took %v to acquire global lock" , took )
2023-10-12 15:42:36 +02:00
if s . metrics != nil {
s . metrics . StoreMetrics ( ) . CountGlobalLockAcquisitionDuration ( took )
}
return unlock
}
2024-07-31 14:53:32 +02:00
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func ( s * SqlStore ) AcquireWriteLockByUID ( ctx context . Context , uniqueID string ) ( unlock func ( ) ) {
log . WithContext ( ctx ) . Tracef ( "acquiring write lock for ID %s" , uniqueID )
2023-10-12 15:42:36 +02:00
start := time . Now ( )
2024-07-31 14:53:32 +02:00
value , _ := s . resourceLocks . LoadOrStore ( uniqueID , & sync . RWMutex { } )
2024-05-07 14:30:03 +02:00
mtx := value . ( * sync . RWMutex )
2023-10-12 15:42:36 +02:00
mtx . Lock ( )
unlock = func ( ) {
mtx . Unlock ( )
2024-07-31 14:53:32 +02:00
log . WithContext ( ctx ) . Tracef ( "released write lock for ID %s in %v" , uniqueID , time . Since ( start ) )
2024-05-07 14:30:03 +02:00
}
return unlock
}
2024-07-31 14:53:32 +02:00
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func ( s * SqlStore ) AcquireReadLockByUID ( ctx context . Context , uniqueID string ) ( unlock func ( ) ) {
log . WithContext ( ctx ) . Tracef ( "acquiring read lock for ID %s" , uniqueID )
2024-05-07 14:30:03 +02:00
start := time . Now ( )
2024-07-31 14:53:32 +02:00
value , _ := s . resourceLocks . LoadOrStore ( uniqueID , & sync . RWMutex { } )
2024-05-07 14:30:03 +02:00
mtx := value . ( * sync . RWMutex )
mtx . RLock ( )
unlock = func ( ) {
mtx . RUnlock ( )
2024-07-31 14:53:32 +02:00
log . WithContext ( ctx ) . Tracef ( "released read lock for ID %s in %v" , uniqueID , time . Since ( start ) )
2023-10-12 15:42:36 +02:00
}
return unlock
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) SaveAccount ( ctx context . Context , account * Account ) error {
2023-10-12 15:42:36 +02:00
start := time . Now ( )
2024-08-23 18:42:55 +02:00
defer func ( ) {
elapsed := time . Since ( start )
if elapsed > 1 * time . Second {
log . WithContext ( ctx ) . Tracef ( "SaveAccount for account %s exceeded 1s, took: %v" , account . Id , elapsed )
}
} ( )
2023-10-12 15:42:36 +02:00
2024-07-02 12:40:26 +02:00
// todo: remove this check after the issue is resolved
2024-07-03 11:33:02 +02:00
s . checkAccountDomainBeforeSave ( ctx , account . Id , account . Domain )
2023-10-12 15:42:36 +02:00
2024-07-02 12:40:26 +02:00
generateAccountSQLTypes ( account )
2023-10-12 15:42:36 +02:00
err := s . db . Transaction ( func ( tx * gorm . DB ) error {
result := tx . Select ( clause . Associations ) . Delete ( account . Policies , "account_id = ?" , account . Id )
if result . Error != nil {
return result . Error
}
result = tx . Select ( clause . Associations ) . Delete ( account . UsersG , "account_id = ?" , account . Id )
if result . Error != nil {
return result . Error
}
result = tx . Select ( clause . Associations ) . Delete ( account )
if result . Error != nil {
return result . Error
}
result = tx .
Session ( & gorm . Session { FullSaveAssociations : true } ) .
2024-04-20 22:04:20 +02:00
Clauses ( clause . OnConflict { UpdateAll : true } ) .
Create ( account )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
return result . Error
}
return nil
} )
took := time . Since ( start )
if s . metrics != nil {
s . metrics . StoreMetrics ( ) . CountPersistenceDuration ( took )
}
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Debugf ( "took %d ms to persist an account to the store" , took . Milliseconds ( ) )
2023-10-12 15:42:36 +02:00
return err
}
2024-07-02 12:40:26 +02:00
// generateAccountSQLTypes generates the GORM compatible types for the account.
//
// Each map field is copied by value into the matching *G slice:
// SetupKeys, Peers, Users (and each user's PATs), Groups, Routes and
// NameServerGroups. For every map except SetupKeys, the map key is written
// back into the element's ID field first. The maps' elements are modified in
// place.
//
// The slices are rebuilt from scratch on every call. Previously the function
// appended to whatever the slices already held. Flattening the same account
// twice (e.g. saving it twice) therefore duplicated rows. An
// INSERT ... ON CONFLICT DO UPDATE cannot apply such duplicates on PostgreSQL.
func generateAccountSQLTypes(account *Account) {
	account.SetupKeysG = nil
	for _, key := range account.SetupKeys {
		account.SetupKeysG = append(account.SetupKeysG, *key)
	}

	account.PeersG = nil
	for id, peer := range account.Peers {
		peer.ID = id
		account.PeersG = append(account.PeersG, *peer)
	}

	account.UsersG = nil
	for id, user := range account.Users {
		user.Id = id
		user.PATsG = nil
		for patID, pat := range user.PATs {
			pat.ID = patID
			user.PATsG = append(user.PATsG, *pat)
		}
		account.UsersG = append(account.UsersG, *user)
	}

	account.GroupsG = nil
	for id, group := range account.Groups {
		group.ID = id
		account.GroupsG = append(account.GroupsG, *group)
	}

	account.RoutesG = nil
	for id, r := range account.Routes {
		r.ID = id
		account.RoutesG = append(account.RoutesG, *r)
	}

	account.NameServerGroupsG = nil
	for id, ns := range account.NameServerGroups {
		ns.ID = id
		account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
	}
}
// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) checkAccountDomainBeforeSave ( ctx context . Context , accountID , newDomain string ) {
2024-07-02 12:40:26 +02:00
var acc Account
var domain string
result := s . db . Model ( & acc ) . Select ( "domain" ) . Where ( idQueryCondition , accountID ) . First ( & domain )
if result . Error != nil {
if ! errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Errorf ( "error when getting account %s from the store to check domain: %s" , accountID , result . Error )
2024-07-02 12:40:26 +02:00
}
return
}
if domain != "" && newDomain == "" {
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Warnf ( "saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s" , domain , accountID , debug . Stack ( ) )
2024-07-02 12:40:26 +02:00
}
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) DeleteAccount ( ctx context . Context , account * Account ) error {
2023-11-28 14:23:38 +01:00
start := time . Now ( )
err := s . db . Transaction ( func ( tx * gorm . DB ) error {
result := tx . Select ( clause . Associations ) . Delete ( account . Policies , "account_id = ?" , account . Id )
if result . Error != nil {
return result . Error
}
result = tx . Select ( clause . Associations ) . Delete ( account . UsersG , "account_id = ?" , account . Id )
if result . Error != nil {
return result . Error
}
result = tx . Select ( clause . Associations ) . Delete ( account )
if result . Error != nil {
return result . Error
}
return nil
} )
took := time . Since ( start )
if s . metrics != nil {
s . metrics . StoreMetrics ( ) . CountPersistenceDuration ( took )
}
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Debugf ( "took %d ms to delete an account to the store" , took . Milliseconds ( ) )
2023-11-28 14:23:38 +01:00
return err
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) SaveInstallationID ( _ context . Context , ID string ) error {
2023-10-12 15:42:36 +02:00
installation := installation { InstallationIDValue : ID }
installation . ID = uint ( s . installationPK )
return s . db . Clauses ( clause . OnConflict { UpdateAll : true } ) . Create ( & installation ) . Error
}
2024-05-16 18:28:37 +02:00
func ( s * SqlStore ) GetInstallationID ( ) string {
2023-10-12 15:42:36 +02:00
var installation installation
2024-07-02 12:40:26 +02:00
if result := s . db . First ( & installation , idQueryCondition , s . installationPK ) ; result . Error != nil {
2023-10-12 15:42:36 +02:00
return ""
}
return installation . InstallationIDValue
}
2024-07-26 07:49:05 +02:00
// SavePeer overwrites the stored record of an existing peer in accountID.
//
// The write goes through a copy of peer whose AccountID is set to accountID.
// If no peer with peer.ID exists in the account, SavePeer returns a NotFound
// status error. The existence check and the write share one transaction bound
// to ctx.
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
	// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
	peerCopy := peer.Copy()
	peerCopy.AccountID = accountID

	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// check if peer exists before saving
		var existingID string
		if err := tx.Model(&nbpeer.Peer{}).Select("id").Find(&existingID, accountAndIDQueryCondition, accountID, peer.ID).Error; err != nil {
			return err
		}
		if existingID == "" {
			return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
		}
		return tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy).Error
	})
}
2024-05-16 18:28:37 +02:00
func ( s * SqlStore ) SavePeerStatus ( accountID , peerID string , peerStatus nbpeer . PeerStatus ) error {
2024-05-07 14:30:03 +02:00
var peerCopy nbpeer . Peer
peerCopy . Status = & peerStatus
2024-07-16 17:38:12 +02:00
fieldsToUpdate := [ ] string {
"peer_status_last_seen" , "peer_status_connected" ,
"peer_status_login_expired" , "peer_status_required_approval" ,
}
2024-05-07 14:30:03 +02:00
result := s . db . Model ( & nbpeer . Peer { } ) .
2024-07-16 17:38:12 +02:00
Select ( fieldsToUpdate ) .
2024-07-26 07:49:05 +02:00
Where ( accountAndIDQueryCondition , accountID , peerID ) .
2024-07-16 17:38:12 +02:00
Updates ( & peerCopy )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-05-07 14:30:03 +02:00
return result . Error
2023-10-12 15:42:36 +02:00
}
2024-05-07 14:30:03 +02:00
if result . RowsAffected == 0 {
2024-07-26 07:49:05 +02:00
return status . Errorf ( status . NotFound , peerNotFoundFMT , peerID )
2024-05-07 14:30:03 +02:00
}
2023-10-12 15:42:36 +02:00
2024-05-07 14:30:03 +02:00
return nil
2023-10-12 15:42:36 +02:00
}
2024-05-16 18:28:37 +02:00
func ( s * SqlStore ) SavePeerLocation ( accountID string , peerWithLocation * nbpeer . Peer ) error {
2024-05-07 14:30:03 +02:00
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer . Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy . Location = peerWithLocation . Location
result := s . db . Model ( & nbpeer . Peer { } ) .
2024-07-26 07:49:05 +02:00
Where ( accountAndIDQueryCondition , accountID , peerWithLocation . ID ) .
2024-05-07 14:30:03 +02:00
Updates ( peerCopy )
2024-02-20 09:59:56 +01:00
if result . Error != nil {
2024-05-07 14:30:03 +02:00
return result . Error
2024-02-20 09:59:56 +01:00
}
2024-05-07 14:30:03 +02:00
if result . RowsAffected == 0 {
2024-07-26 07:49:05 +02:00
return status . Errorf ( status . NotFound , peerNotFoundFMT , peerWithLocation . ID )
2024-05-07 14:30:03 +02:00
}
2024-02-20 09:59:56 +01:00
2024-05-07 14:30:03 +02:00
return nil
2024-02-20 09:59:56 +01:00
}
2024-07-15 16:04:06 +02:00
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
//
// Before the upsert, each user's AccountID is set to accountID. Each user's
// PATsG slice is rebuilt from its PATs map, with the map key becoming the PAT
// ID. The PATs are then saved through FullSaveAssociations. The users in the
// map are modified in place.
//
// An empty map is a no-op. gorm refuses to Create an empty slice
// (gorm.ErrEmptySlice), so an empty map previously returned an error.
// PATsG is reset rather than appended to, so saving the same user twice no
// longer sends duplicate PAT rows in a single upsert.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
	if len(users) == 0 {
		return nil
	}

	usersToSave := make([]User, 0, len(users))
	for _, user := range users {
		user.AccountID = accountID
		user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
		for id, pat := range user.PATs {
			pat.ID = id
			user.PATsG = append(user.PATsG, *pat)
		}
		usersToSave = append(usersToSave, *user)
	}

	return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
		Clauses(clause.OnConflict{UpdateAll: true}).
		Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
//
// Each group's AccountID is set to accountID; the groups in the map are
// modified in place. An empty map is a no-op, because gorm rejects Create on
// an empty slice with gorm.ErrEmptySlice.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
	if len(groups) == 0 {
		return nil
	}

	groupsToSave := make([]nbgroup.Group, 0, len(groups))
	for _, group := range groups {
		group.AccountID = accountID
		groupsToSave = append(groupsToSave, *group)
	}
	return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
2024-05-16 18:28:37 +02:00
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func ( s * SqlStore ) DeleteHashedPAT2TokenIDIndex ( hashedToken string ) error {
2023-10-12 15:42:36 +02:00
return nil
}
2024-05-16 18:28:37 +02:00
// DeleteTokenID2UserIDIndex is noop in SqlStore
func ( s * SqlStore ) DeleteTokenID2UserIDIndex ( tokenID string ) error {
2023-10-12 15:42:36 +02:00
return nil
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountByPrivateDomain ( ctx context . Context , domain string ) ( * Account , error ) {
2024-09-24 12:30:13 +02:00
accountID , err := s . GetAccountIDByPrivateDomain ( ctx , LockingStrengthShare , domain )
2024-09-18 14:55:52 +02:00
if err != nil {
return nil , err
}
// TODO: rework to not call GetAccount
return s . GetAccount ( ctx , accountID )
}
2024-09-24 12:30:13 +02:00
func ( s * SqlStore ) GetAccountIDByPrivateDomain ( ctx context . Context , lockStrength LockingStrength , domain string ) ( string , error ) {
var accountID string
result := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) . Model ( & Account { } ) . Select ( "id" ) .
Where ( "domain = ? and is_domain_primary_account = ? and domain_category = ?" ,
strings . ToLower ( domain ) , true , PrivateCategory ,
) . First ( & accountID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-09-18 14:55:52 +02:00
return "" , status . Errorf ( status . NotFound , "account not found: provided domain is not registered or is not private" )
2024-03-10 19:09:45 +01:00
}
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Errorf ( "error when getting account from the store: %s" , result . Error )
2024-09-18 14:55:52 +02:00
return "" , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
2024-09-24 12:30:13 +02:00
return accountID , nil
2023-10-12 15:42:36 +02:00
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountBySetupKey ( ctx context . Context , setupKey string ) ( * Account , error ) {
2023-10-12 15:42:36 +02:00
var key SetupKey
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Select ( "account_id" ) . First ( & key , keyQueryCondition , strings . ToUpper ( setupKey ) )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-09-16 15:47:03 +02:00
return nil , status . NewSetupKeyNotFoundError ( )
2023-10-12 15:42:36 +02:00
}
if key . AccountID == "" {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
return s . GetAccount ( ctx , key . AccountID )
2023-10-12 15:42:36 +02:00
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetTokenIDByHashedToken ( ctx context . Context , hashedToken string ) ( string , error ) {
2023-10-12 15:42:36 +02:00
var token PersonalAccessToken
result := s . db . First ( & token , "hashed_token = ?" , hashedToken )
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return "" , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Errorf ( "error when getting token from the store: %s" , result . Error )
2024-03-10 19:09:45 +01:00
return "" , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
return token . ID , nil
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetUserByTokenID ( ctx context . Context , tokenID string ) ( * User , error ) {
2023-10-12 15:42:36 +02:00
var token PersonalAccessToken
2024-07-02 12:40:26 +02:00
result := s . db . First ( & token , idQueryCondition , tokenID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Errorf ( "error when getting token from the store: %s" , result . Error )
2024-03-10 19:09:45 +01:00
return nil , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
if token . UserID == "" {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
var user User
2024-07-02 12:40:26 +02:00
result = s . db . Preload ( "PATsG" ) . First ( & user , idQueryCondition , token . UserID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
user . PATs = make ( map [ string ] * PersonalAccessToken , len ( user . PATsG ) )
for _ , pat := range user . PATsG {
2023-11-15 14:15:12 +01:00
user . PATs [ pat . ID ] = pat . Copy ( )
2023-10-12 15:42:36 +02:00
}
return & user , nil
}
2024-09-16 15:47:03 +02:00
func ( s * SqlStore ) GetUserByUserID ( ctx context . Context , lockStrength LockingStrength , userID string ) ( * User , error ) {
2024-08-19 12:50:11 +02:00
var user User
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) .
2024-09-24 20:57:33 +02:00
Preload ( clause . Associations ) . First ( & user , idQueryCondition , userID )
2024-08-19 12:50:11 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-09-16 15:47:03 +02:00
return nil , status . NewUserNotFoundError ( userID )
2024-08-19 12:50:11 +02:00
}
2024-09-16 15:47:03 +02:00
return nil , status . NewGetUserFromStoreError ( )
2024-08-19 12:50:11 +02:00
}
return & user , nil
}
func ( s * SqlStore ) GetAccountGroups ( ctx context . Context , accountID string ) ( [ ] * nbgroup . Group , error ) {
var groups [ ] * nbgroup . Group
2024-09-22 14:14:31 +02:00
result := s . db . Find ( & groups , accountIDCondition , accountID )
2024-08-19 12:50:11 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "accountID not found: index lookup failed" )
}
log . WithContext ( ctx ) . Errorf ( "error when getting groups from the store: %s" , result . Error )
return nil , status . Errorf ( status . Internal , "issue getting groups from store" )
}
return groups , nil
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAllAccounts ( ctx context . Context ) ( all [ ] * Account ) {
2023-10-12 15:42:36 +02:00
var accounts [ ] Account
result := s . db . Find ( & accounts )
if result . Error != nil {
return all
}
for _ , account := range accounts {
2024-07-03 11:33:02 +02:00
if acc , err := s . GetAccount ( ctx , account . Id ) ; err == nil {
2023-10-12 15:42:36 +02:00
all = append ( all , acc )
}
}
return all
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccount ( ctx context . Context , accountID string ) ( * Account , error ) {
2024-08-23 18:42:55 +02:00
start := time . Now ( )
defer func ( ) {
elapsed := time . Since ( start )
if elapsed > 1 * time . Second {
log . WithContext ( ctx ) . Tracef ( "GetAccount for account %s exceeded 1s, took: %v" , accountID , elapsed )
}
} ( )
2024-05-07 14:30:03 +02:00
2023-10-12 15:42:36 +02:00
var account Account
result := s . db . Model ( & account ) .
Preload ( "UsersG.PATsG" ) . // have to be specifies as this is nester reference
Preload ( clause . Associations ) .
2024-07-02 12:40:26 +02:00
First ( & account , idQueryCondition , accountID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-07-03 11:33:02 +02:00
log . WithContext ( ctx ) . Errorf ( "error when getting account %s from the store: %s" , accountID , result . Error )
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-09-16 15:47:03 +02:00
return nil , status . NewAccountNotFoundError ( accountID )
2024-03-10 19:09:45 +01:00
}
return nil , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i , policy := range account . Policies {
var rules [ ] * PolicyRule
err := s . db . Model ( & PolicyRule { } ) . Find ( & rules , "policy_id = ?" , policy . ID ) . Error
if err != nil {
2023-11-28 14:23:38 +01:00
return nil , status . Errorf ( status . NotFound , "rule not found" )
2023-10-12 15:42:36 +02:00
}
account . Policies [ i ] . Rules = rules
}
account . SetupKeys = make ( map [ string ] * SetupKey , len ( account . SetupKeysG ) )
for _ , key := range account . SetupKeysG {
account . SetupKeys [ key . Key ] = key . Copy ( )
}
account . SetupKeysG = nil
2023-11-28 13:45:26 +01:00
account . Peers = make ( map [ string ] * nbpeer . Peer , len ( account . PeersG ) )
2023-10-12 15:42:36 +02:00
for _ , peer := range account . PeersG {
account . Peers [ peer . ID ] = peer . Copy ( )
}
account . PeersG = nil
account . Users = make ( map [ string ] * User , len ( account . UsersG ) )
for _ , user := range account . UsersG {
user . PATs = make ( map [ string ] * PersonalAccessToken , len ( user . PATs ) )
for _ , pat := range user . PATsG {
user . PATs [ pat . ID ] = pat . Copy ( )
}
account . Users [ user . Id ] = user . Copy ( )
}
account . UsersG = nil
2024-03-27 18:48:48 +01:00
account . Groups = make ( map [ string ] * nbgroup . Group , len ( account . GroupsG ) )
2023-10-12 15:42:36 +02:00
for _ , group := range account . GroupsG {
account . Groups [ group . ID ] = group . Copy ( )
}
account . GroupsG = nil
2024-05-06 14:47:49 +02:00
account . Routes = make ( map [ route . ID ] * route . Route , len ( account . RoutesG ) )
2023-10-12 15:42:36 +02:00
for _ , route := range account . RoutesG {
account . Routes [ route . ID ] = route . Copy ( )
}
account . RoutesG = nil
account . NameServerGroups = make ( map [ string ] * nbdns . NameServerGroup , len ( account . NameServerGroupsG ) )
for _ , ns := range account . NameServerGroupsG {
account . NameServerGroups [ ns . ID ] = ns . Copy ( )
}
account . NameServerGroupsG = nil
return & account , nil
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountByUser ( ctx context . Context , userID string ) ( * Account , error ) {
2023-10-12 15:42:36 +02:00
var user User
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Select ( "account_id" ) . First ( & user , idQueryCondition , userID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
return nil , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
if user . AccountID == "" {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
return s . GetAccount ( ctx , user . AccountID )
2023-10-12 15:42:36 +02:00
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountByPeerID ( ctx context . Context , peerID string ) ( * Account , error ) {
2023-11-28 13:45:26 +01:00
var peer nbpeer . Peer
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Select ( "account_id" ) . First ( & peer , idQueryCondition , peerID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
return nil , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
if peer . AccountID == "" {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
return s . GetAccount ( ctx , peer . AccountID )
2023-10-12 15:42:36 +02:00
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountByPeerPubKey ( ctx context . Context , peerKey string ) ( * Account , error ) {
2023-11-28 13:45:26 +01:00
var peer nbpeer . Peer
2023-10-12 15:42:36 +02:00
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Select ( "account_id" ) . First ( & peer , keyQueryCondition , peerKey )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
return nil , status . Errorf ( status . Internal , "issue getting account from store" )
2023-10-12 15:42:36 +02:00
}
if peer . AccountID == "" {
return nil , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-07-03 11:33:02 +02:00
return s . GetAccount ( ctx , peer . AccountID )
2023-10-12 15:42:36 +02:00
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountIDByPeerPubKey ( ctx context . Context , peerKey string ) ( string , error ) {
2024-05-07 14:30:03 +02:00
var peer nbpeer . Peer
var accountID string
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Model ( & peer ) . Select ( "account_id" ) . Where ( keyQueryCondition , peerKey ) . First ( & accountID )
2024-05-07 14:30:03 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return "" , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
return "" , status . Errorf ( status . Internal , "issue getting account from store" )
}
return accountID , nil
}
2024-05-31 16:41:12 +02:00
func ( s * SqlStore ) GetAccountIDByUserID ( userID string ) ( string , error ) {
var accountID string
2024-09-24 12:30:13 +02:00
result := s . db . Model ( & User { } ) . Select ( "account_id" ) . Where ( idQueryCondition , userID ) . First ( & accountID )
2024-05-31 16:41:12 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return "" , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
return "" , status . Errorf ( status . Internal , "issue getting account from store" )
}
return accountID , nil
}
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) GetAccountIDBySetupKey ( ctx context . Context , setupKey string ) ( string , error ) {
2024-05-31 16:41:12 +02:00
var accountID string
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Model ( & SetupKey { } ) . Select ( "account_id" ) . Where ( keyQueryCondition , strings . ToUpper ( setupKey ) ) . First ( & accountID )
2024-05-31 16:41:12 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return "" , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
}
2024-09-16 15:47:03 +02:00
return "" , status . NewSetupKeyNotFoundError ( )
}
if accountID == "" {
return "" , status . Errorf ( status . NotFound , "account not found: index lookup failed" )
2024-05-31 16:41:12 +02:00
}
return accountID , nil
}
2024-09-16 15:47:03 +02:00
// GetTakenIPs returns the IP addresses of all peers in accountID under the
// requested row lock. Each address is decoded from the JSON-encoded "ip" column.
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
	// Fetch the IP addresses as JSON strings
	var rawIPs []string
	err := s.db.WithContext(ctx).
		Clauses(clause.Locking{Strength: string(lockStrength)}).
		Model(&nbpeer.Peer{}).
		Where("account_id = ?", accountID).
		Pluck("ip", &rawIPs).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, status.Errorf(status.NotFound, "no peers found for the account")
		}
		return nil, status.Errorf(status.Internal, "issue getting IPs from store")
	}

	// Convert the JSON strings to net.IP objects
	ips := make([]net.IP, 0, len(rawIPs))
	for _, raw := range rawIPs {
		var ip net.IP
		if decodeErr := json.Unmarshal([]byte(raw), &ip); decodeErr != nil {
			return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
		}
		ips = append(ips, ip)
	}
	return ips, nil
}
// GetPeerLabelsInAccount returns the DNS labels ("dns_label" column) of all
// peers in the given account, read with the requested lock strength.
// Unexpected database errors are logged and reported as an Internal status error.
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
	var dnsLabels []string
	err := s.db.WithContext(ctx).
		Clauses(clause.Locking{Strength: string(lockStrength)}).
		Model(&nbpeer.Peer{}).
		Where(accountIDCondition, accountID).
		Pluck("dns_label", &dnsLabels).Error
	if err == nil {
		return dnsLabels, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, status.Errorf(status.NotFound, "no peers found for the account")
	}
	log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", err)
	return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
// GetAccountNetwork returns the network configuration of the given account,
// read with the requested row-locking strength.
//
// A NotFound (account-not-found) status error is returned when the account
// does not exist; any other database failure is reported as an Internal status error.
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
	var accountNetwork AccountNetwork
	// Honour the caller's requested lock strength, as the other lock-aware
	// getters in this store do; it was previously accepted but ignored.
	err := s.db.WithContext(ctx).
		Clauses(clause.Locking{Strength: string(lockStrength)}).
		Model(&Account{}).
		Where(idQueryCondition, accountID).
		First(&accountNetwork).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, status.NewAccountNotFoundError(accountID)
		}
		return nil, status.Errorf(status.Internal, "issue getting network from store")
	}
	return accountNetwork.Network, nil
}
func ( s * SqlStore ) GetPeerByPeerPubKey ( ctx context . Context , lockStrength LockingStrength , peerKey string ) ( * nbpeer . Peer , error ) {
2024-05-31 16:41:12 +02:00
var peer nbpeer . Peer
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) . First ( & peer , keyQueryCondition , peerKey )
2024-05-31 16:41:12 +02:00
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "peer not found" )
}
return nil , status . Errorf ( status . Internal , "issue getting peer from store" )
}
return & peer , nil
}
2024-09-16 15:47:03 +02:00
func ( s * SqlStore ) GetAccountSettings ( ctx context . Context , lockStrength LockingStrength , accountID string ) ( * Settings , error ) {
2024-05-31 16:41:12 +02:00
var accountSettings AccountSettings
2024-09-16 15:47:03 +02:00
if err := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) . Model ( & Account { } ) . Where ( idQueryCondition , accountID ) . First ( & accountSettings ) . Error ; err != nil {
2024-05-31 16:41:12 +02:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "settings not found" )
}
return nil , status . Errorf ( status . Internal , "issue getting settings from store" )
}
return accountSettings . Settings , nil
}
2023-10-12 15:42:36 +02:00
// SaveUserLastLogin stores the last login time for a user in DB.
2024-09-16 15:47:03 +02:00
func ( s * SqlStore ) SaveUserLastLogin ( ctx context . Context , accountID , userID string , lastLogin time . Time ) error {
2023-10-23 16:08:21 +02:00
var user User
2023-10-12 15:42:36 +02:00
2024-09-16 15:47:03 +02:00
result := s . db . WithContext ( ctx ) . First ( & user , accountAndIDQueryCondition , accountID , userID )
2023-10-12 15:42:36 +02:00
if result . Error != nil {
2024-03-10 19:09:45 +01:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-09-16 15:47:03 +02:00
return status . NewUserNotFoundError ( userID )
2024-03-10 19:09:45 +01:00
}
2024-09-16 15:47:03 +02:00
return status . NewGetUserFromStoreError ( )
2023-10-12 15:42:36 +02:00
}
2023-10-23 16:08:21 +02:00
user . LastLogin = lastLogin
2023-10-12 15:42:36 +02:00
2024-09-16 15:47:03 +02:00
return s . db . Save ( & user ) . Error
2023-10-12 15:42:36 +02:00
}
2024-05-30 15:22:42 +02:00
// GetPostureCheckByChecksDefinition finds a posture check in the given account
// whose stored "checks" column equals the JSON encoding of checks.
// JSON-encoding and database errors (including gorm.ErrRecordNotFound) are
// returned to the caller unchanged.
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
	encoded, err := json.Marshal(checks)
	if err != nil {
		return nil, err
	}

	found := &posture.Checks{}
	if err := s.db.Where("account_id = ? AND checks = ?", accountID, string(encoded)).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
2024-02-22 12:27:08 +01:00
// Close closes the underlying DB connection
2024-07-03 11:33:02 +02:00
func ( s * SqlStore ) Close ( _ context . Context ) error {
2024-02-20 15:06:32 +01:00
sql , err := s . db . DB ( )
if err != nil {
2024-02-22 12:27:08 +01:00
return fmt . Errorf ( "get db: %w" , err )
2024-02-20 15:06:32 +01:00
}
return sql . Close ( )
2023-10-12 15:42:36 +02:00
}
2024-05-16 18:28:37 +02:00
// GetStoreEngine reports which SQL engine backs this store, as passed to
// NewSqlStore (for example SqliteStoreEngine or PostgresStoreEngine).
func (s *SqlStore) GetStoreEngine() StoreEngine {
	engine := s.storeEngine
	return engine
}
// NewSqliteStore creates a new SQLite store.
2024-07-03 11:33:02 +02:00
func NewSqliteStore ( ctx context . Context , dataDir string , metrics telemetry . AppMetrics ) ( * SqlStore , error ) {
2024-06-13 12:39:19 +02:00
storeStr := fmt . Sprintf ( "%s?cache=shared" , storeSqliteFileName )
2024-05-16 18:28:37 +02:00
if runtime . GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
2024-06-13 12:39:19 +02:00
storeStr = storeSqliteFileName
2024-05-16 18:28:37 +02:00
}
file := filepath . Join ( dataDir , storeStr )
2024-07-12 09:28:53 +02:00
db , err := gorm . Open ( sqlite . Open ( file ) , getGormConfig ( ) )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
2024-07-03 11:33:02 +02:00
return NewSqlStore ( ctx , db , SqliteStoreEngine , metrics )
2024-05-16 18:28:37 +02:00
}
// NewPostgresqlStore creates a new Postgres store.
2024-07-03 11:33:02 +02:00
func NewPostgresqlStore ( ctx context . Context , dsn string , metrics telemetry . AppMetrics ) ( * SqlStore , error ) {
2024-07-12 09:28:53 +02:00
db , err := gorm . Open ( postgres . Open ( dsn ) , getGormConfig ( ) )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
2024-07-03 11:33:02 +02:00
return NewSqlStore ( ctx , db , PostgresStoreEngine , metrics )
2023-10-12 15:42:36 +02:00
}
2024-04-18 18:14:21 +02:00
2024-07-12 09:28:53 +02:00
// getGormConfig returns the GORM settings shared by every SQL engine: a
// silenced logger, inserts batched 400 rows at a time, and prepared
// statements enabled (PrepareStmt).
func getGormConfig() *gorm.Config {
	cfg := &gorm.Config{}
	cfg.Logger = logger.Default.LogMode(logger.Silent)
	cfg.CreateBatchSize = 400
	cfg.PrepareStmt = true
	return cfg
}
2024-06-13 12:39:19 +02:00
// newPostgresStore initializes a new Postgres store.
2024-07-03 11:33:02 +02:00
func newPostgresStore ( ctx context . Context , metrics telemetry . AppMetrics ) ( Store , error ) {
2024-06-13 12:39:19 +02:00
dsn , ok := os . LookupEnv ( postgresDsnEnv )
if ! ok {
return nil , fmt . Errorf ( "%s is not set" , postgresDsnEnv )
}
2024-07-03 11:33:02 +02:00
return NewPostgresqlStore ( ctx , dsn , metrics )
2024-06-13 12:39:19 +02:00
}
2024-05-16 18:28:37 +02:00
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
2024-07-03 11:33:02 +02:00
func NewSqliteStoreFromFileStore ( ctx context . Context , fileStore * FileStore , dataDir string , metrics telemetry . AppMetrics ) ( * SqlStore , error ) {
store , err := NewSqliteStore ( ctx , dataDir , metrics )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
2024-04-18 18:14:21 +02:00
2024-07-03 11:33:02 +02:00
err = store . SaveInstallationID ( ctx , fileStore . InstallationID )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
2024-07-03 11:33:02 +02:00
for _ , account := range fileStore . GetAllAccounts ( ctx ) {
err := store . SaveAccount ( ctx , account )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
2024-04-18 18:14:21 +02:00
}
}
2024-05-16 18:28:37 +02:00
return store , nil
2024-04-18 18:14:21 +02:00
}
2024-05-16 18:28:37 +02:00
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
2024-07-03 11:33:02 +02:00
func NewPostgresqlStoreFromFileStore ( ctx context . Context , fileStore * FileStore , dsn string , metrics telemetry . AppMetrics ) ( * SqlStore , error ) {
store , err := NewPostgresqlStore ( ctx , dsn , metrics )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
2024-04-18 18:14:21 +02:00
}
2024-05-16 18:28:37 +02:00
2024-07-03 11:33:02 +02:00
err = store . SaveInstallationID ( ctx , fileStore . InstallationID )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
2024-07-03 11:33:02 +02:00
for _ , account := range fileStore . GetAllAccounts ( ctx ) {
err := store . SaveAccount ( ctx , account )
2024-05-16 18:28:37 +02:00
if err != nil {
return nil , err
}
}
return store , nil
2024-04-18 18:14:21 +02:00
}
2024-09-16 15:47:03 +02:00
// GetSetupKeyBySecret returns the setup key whose key value matches key
// (key is upper-cased before comparison), read with the requested lock strength.
//
// A NotFound status error is returned when no setup key matches. Other
// database failures are logged with their cause and reported to the caller as
// a setup-key-not-found error.
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
	var setupKey SetupKey
	result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
		First(&setupKey, keyQueryCondition, strings.ToUpper(key))
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return nil, status.Errorf(status.NotFound, "setup key not found")
		}
		// Keep the caller-facing error unchanged, but record the real cause
		// (e.g. a connection failure) instead of discarding it. The key itself
		// is a secret and is deliberately not logged.
		log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %s", result.Error)
		return nil, status.NewSetupKeyNotFoundError()
	}
	return &setupKey, nil
}
// IncrementSetupKeyUsage increments the "used_times" counter of the setup key
// with the given ID and sets "last_used" to the current time, in a single
// UPDATE statement. It returns a NotFound status error if no row was updated.
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
	res := s.db.WithContext(ctx).
		Model(&SetupKey{}).
		Where(idQueryCondition, setupKeyID).
		Updates(map[string]interface{}{
			"used_times": gorm.Expr("used_times + 1"),
			"last_used":  time.Now(),
		})

	switch {
	case res.Error != nil:
		return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", res.Error)
	case res.RowsAffected == 0:
		return status.Errorf(status.NotFound, "setup key not found")
	default:
		return nil
	}
}
// AddPeerToAllGroup adds peerID to the group named "All" of the given account.
// It is a no-op when the peer is already a member.
//
// It returns a NotFound status error when the account has no "All" group and
// an Internal status error (after logging the cause) on other database failures.
//
// NOTE(review): the group is read and then saved without a row lock, so
// concurrent additions to the same group can overwrite each other unless the
// caller serializes them.
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
	var group nbgroup.Group
	result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return status.Errorf(status.NotFound, "group 'All' not found for account")
		}
		log.WithContext(ctx).Errorf("failed to get group 'All' from the store: %s", result.Error)
		return status.Errorf(status.Internal, "issue finding group 'All'")
	}

	for _, existingPeerID := range group.Peers {
		if existingPeerID == peerID {
			return nil
		}
	}

	group.Peers = append(group.Peers, peerID)

	// Use the request context for the write as well as the read.
	if err := s.db.WithContext(ctx).Save(&group).Error; err != nil {
		log.WithContext(ctx).Errorf("failed to save group 'All' to the store: %s", err)
		return status.Errorf(status.Internal, "issue updating group 'All'")
	}
	return nil
}
// AddPeerToGroup adds peerId to the group groupID of account accountId.
// It is a no-op when the peer is already a member.
//
// It returns a NotFound status error when the group does not exist in the
// account and an Internal status error (after logging the cause) on other
// database failures.
//
// NOTE(review): the group is read and then saved without a row lock, so
// concurrent additions to the same group can overwrite each other unless the
// caller serializes them.
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
	var group nbgroup.Group
	result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return status.Errorf(status.NotFound, "group not found for account")
		}
		log.WithContext(ctx).Errorf("failed to get group %s from the store: %s", groupID, result.Error)
		return status.Errorf(status.Internal, "issue finding group")
	}

	for _, existingPeerID := range group.Peers {
		if existingPeerID == peerId {
			return nil
		}
	}

	group.Peers = append(group.Peers, peerId)

	// Use the request context for the write as well as the read.
	if err := s.db.WithContext(ctx).Save(&group).Error; err != nil {
		log.WithContext(ctx).Errorf("failed to save group %s to the store: %s", groupID, err)
		return status.Errorf(status.Internal, "issue updating group")
	}
	return nil
}
// AddPeerToAccount inserts peer as a new row. Any insert failure is reported
// as an Internal status error.
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
	err := s.db.WithContext(ctx).Create(peer).Error
	if err == nil {
		return nil
	}
	return status.Errorf(status.Internal, "issue adding peer to account")
}
// IncrementNetworkSerial adds one to the "network_serial" column of the given
// account in a single UPDATE statement. A non-existent account is not
// reported as an error (the number of affected rows is not checked).
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
	err := s.db.WithContext(ctx).
		Model(&Account{}).
		Where(idQueryCondition, accountId).
		Update("network_serial", gorm.Expr("network_serial + 1")).Error
	if err != nil {
		return status.Errorf(status.Internal, "issue incrementing network serial count")
	}
	return nil
}
// ExecuteInTransaction runs operation against a Store bound to a new database
// transaction. The transaction is committed when operation returns nil and
// rolled back when it returns an error (that error is returned unchanged).
//
// If operation panics, the transaction is rolled back and the panic is
// re-raised, so the transaction and its connection are never left open.
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
	tx := s.db.WithContext(ctx).Begin()
	if tx.Error != nil {
		return tx.Error
	}

	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	repo := s.withTx(tx)
	if err := operation(repo); err != nil {
		// The operation's error is what the caller needs; a rollback failure is
		// only logged so it does not mask it.
		if rbErr := tx.Rollback().Error; rbErr != nil {
			log.WithContext(ctx).Errorf("failed to rollback transaction: %v", rbErr)
		}
		return err
	}
	return tx.Commit().Error
}
// withTx returns a Store whose queries run on tx. The engine, metrics and
// installation PK are carried over so the transactional store reports the
// same configuration as s (e.g. GetStoreEngine). The lock fields
// (resourceLocks, globalAccountLock) are not copyable, so the returned store
// gets fresh zero-valued locks.
func (s *SqlStore) withTx(tx *gorm.DB) Store {
	return &SqlStore{
		db:             tx,
		storeEngine:    s.storeEngine,
		metrics:        s.metrics,
		installationPK: s.installationPK,
	}
}
2024-09-20 13:07:44 +02:00
2024-09-25 11:53:20 +02:00
// GetAccountDNSSettings returns the DNS settings of the given account, read
// with the requested lock strength. A missing account yields a NotFound
// status error; other failures yield an Internal status error with the cause.
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
	var accountDNSSettings AccountDNSSettings
	err := s.db.WithContext(ctx).
		Clauses(clause.Locking{Strength: string(lockStrength)}).
		Model(&Account{}).
		First(&accountDNSSettings, idQueryCondition, accountID).Error
	switch {
	case err == nil:
		return &accountDNSSettings.DNSSettings, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, status.Errorf(status.NotFound, "dns settings not found")
	default:
		return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", err)
	}
}
2024-09-24 12:30:13 +02:00
// UpdateAccount updates an existing account's domain, DNS settings and
// settings. Only fields that are set on account are written: a non-empty
// Domain, DNSSettings whose DisabledManagementGroups is non-nil, and non-nil
// Settings. When none of these are set, nothing is written and nil is returned.
// A NotFound status error is returned if no row was updated.
//
// NOTE(review): the map keys are handed to GORM as column names; confirm they
// match how DNSSettings/Settings are mapped in the Account schema.
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
	fields := map[string]interface{}{}
	if account.Domain != "" {
		fields["domain"] = account.Domain
	}
	if account.DNSSettings.DisabledManagementGroups != nil {
		fields["dns_settings"] = account.DNSSettings
	}
	if account.Settings != nil {
		fields["settings"] = account.Settings
	}
	if len(fields) == 0 {
		return nil
	}

	res := s.db.WithContext(ctx).
		Clauses(clause.Locking{Strength: string(lockStrength)}).
		Model(&Account{}).
		Where(idQueryCondition, account.Id).
		Updates(fields)

	switch {
	case res.Error != nil:
		return status.Errorf(status.Internal, "failed to update account: %v", res.Error)
	case res.RowsAffected == 0:
		return status.Errorf(status.NotFound, "account not found")
	default:
		return nil
	}
}
// AccountExists checks whether an account exists by the given ID.
func ( s * SqlStore ) AccountExists ( ctx context . Context , id string ) ( bool , error ) {
var count int64
2024-09-25 11:53:20 +02:00
2024-09-24 12:30:13 +02:00
result := s . db . WithContext ( ctx ) . Model ( & Account { } ) . Where ( idQueryCondition , id ) . Count ( & count )
if result . Error != nil {
return false , result . Error
}
return count > 0 , nil
}
2024-09-20 13:07:44 +02:00
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
2024-09-24 12:30:13 +02:00
func ( s * SqlStore ) GetAccountDomainAndCategory ( ctx context . Context , lockStrength LockingStrength , accountID string ) ( string , string , error ) {
2024-09-20 13:07:44 +02:00
var account Account
2024-09-25 11:53:20 +02:00
2024-09-24 12:30:13 +02:00
result := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) . Model ( & Account { } ) . Select ( "domain" , "domain_category" ) .
2024-09-20 13:07:44 +02:00
Where ( idQueryCondition , accountID ) . First ( & account )
if result . Error != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return "" , "" , status . Errorf ( status . NotFound , "account not found" )
}
2024-09-24 18:55:33 +02:00
return "" , "" , status . Errorf ( status . Internal , "failed to get domain category from store: %v" , result . Error )
2024-09-20 13:07:44 +02:00
}
return account . Domain , account . DomainCategory , nil
}
2024-09-24 15:36:57 +02:00
2024-09-25 11:53:20 +02:00
// GetGroupByID
func ( s * SqlStore ) GetGroupByID ( ctx context . Context , lockStrength LockingStrength , groupID , accountID string ) ( * nbgroup . Group , error ) {
return getRecordByID [ nbgroup . Group ] ( s . db . WithContext ( ctx ) . Preload ( clause . Associations ) , lockStrength , groupID , accountID )
2024-09-24 18:55:33 +02:00
}
2024-09-24 15:36:57 +02:00
// GetGroupByName retrieves a group by name and account ID.
func ( s * SqlStore ) GetGroupByName ( ctx context . Context , lockStrength LockingStrength , groupName , accountID string ) ( * nbgroup . Group , error ) {
var group nbgroup . Group
2024-09-25 11:53:20 +02:00
result := s . db . WithContext ( ctx ) . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) . Preload ( clause . Associations ) .
Order ( "json_array_length(peers) DESC" ) . First ( & group , "name = ? and account_id = ?" , groupName , accountID )
2024-09-24 15:36:57 +02:00
if err := result . Error ; err != nil {
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
return nil , status . Errorf ( status . NotFound , "group not found" )
}
2024-09-24 18:55:33 +02:00
return nil , status . Errorf ( status . Internal , "failed to get group from store: %s" , result . Error )
2024-09-24 15:36:57 +02:00
}
return & group , nil
}
2024-09-24 20:57:33 +02:00
2024-09-25 11:53:20 +02:00
// GetAccountPolicies retrieves policies for an account.
2024-09-24 20:57:33 +02:00
func ( s * SqlStore ) GetAccountPolicies ( ctx context . Context , accountID string ) ( [ ] * Policy , error ) {
2024-09-25 11:53:20 +02:00
return getRecords [ * Policy ] ( s . db . WithContext ( ctx ) . Preload ( clause . Associations ) , accountID )
2024-09-24 20:57:33 +02:00
}
2024-09-25 11:53:20 +02:00
// GetPolicyByID retrieves a policy by its ID and account ID.
2024-09-24 20:57:33 +02:00
func ( s * SqlStore ) GetPolicyByID ( ctx context . Context , lockStrength LockingStrength , policyID string , accountID string ) ( * Policy , error ) {
2024-09-25 11:53:20 +02:00
return getRecordByID [ Policy ] ( s . db . WithContext ( ctx ) . Preload ( clause . Associations ) , lockStrength , policyID , accountID )
2024-09-24 20:57:33 +02:00
}
2024-09-25 11:53:20 +02:00
// GetAccountPostureChecks retrieves posture checks for an account.
2024-09-24 20:57:33 +02:00
func ( s * SqlStore ) GetAccountPostureChecks ( ctx context . Context , accountID string ) ( [ ] * posture . Checks , error ) {
2024-09-25 11:53:20 +02:00
return getRecords [ * posture . Checks ] ( s . db . WithContext ( ctx ) , accountID )
2024-09-24 20:57:33 +02:00
}
2024-09-25 11:53:20 +02:00
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
2024-09-24 20:57:33 +02:00
func ( s * SqlStore ) GetPostureChecksByID ( ctx context . Context , lockStrength LockingStrength , postureCheckID string , accountID string ) ( * posture . Checks , error ) {
2024-09-25 11:53:20 +02:00
return getRecordByID [ posture . Checks ] ( s . db . WithContext ( ctx ) , lockStrength , postureCheckID , accountID )
}
// GetAccountRoutes returns all network routes belonging to the given account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
	db := s.db.WithContext(ctx)
	return getRecords[*route.Route](db, accountID)
}
// GetRouteByID returns the route with the given ID in the given account,
// read with the requested lock strength.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
	db := s.db.WithContext(ctx)
	return getRecordByID[route.Route](db, lockStrength, routeID, accountID)
}
// GetAccountSetupKeys returns all setup keys belonging to the given account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
	db := s.db.WithContext(ctx)
	return getRecords[*SetupKey](db, accountID)
}
// GetSetupKeyByID returns the setup key with the given ID in the given
// account, read with the requested lock strength.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
	db := s.db.WithContext(ctx)
	return getRecordByID[SetupKey](db, lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups returns all name server groups belonging to the given account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
	db := s.db.WithContext(ctx)
	return getRecords[*nbdns.NameServerGroup](db, accountID)
}
// GetNameServerGroupByID returns the name server group with the given ID in
// the given account, read with the requested lock strength.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
	db := s.db.WithContext(ctx)
	return getRecordByID[nbdns.NameServerGroup](db, lockStrength, nsGroupID, accountID)
}
// getRecords loads every row of type T whose account_id equals accountID.
// On failure it returns an Internal status error naming the record type,
// taken as the part of %T after the last '.'.
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
	var records []T
	if err := db.Find(&records, accountIDCondition, accountID).Error; err != nil {
		typeName := fmt.Sprintf("%T", records)
		typeName = typeName[strings.LastIndex(typeName, ".")+1:]
		return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", typeName, err)
	}
	return records, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID [ T any ] ( db * gorm . DB , lockStrength LockingStrength , recordID , accountID string ) ( * T , error ) {
var record T
result := db . Clauses ( clause . Locking { Strength : string ( lockStrength ) } ) .
First ( & record , accountAndIDQueryCondition , accountID , recordID )
2024-09-24 20:57:33 +02:00
if err := result . Error ; err != nil {
2024-09-25 11:53:20 +02:00
parts := strings . Split ( fmt . Sprintf ( "%T" , record ) , "." )
recordType := parts [ len ( parts ) - 1 ]
2024-09-24 20:57:33 +02:00
if errors . Is ( result . Error , gorm . ErrRecordNotFound ) {
2024-09-25 11:53:20 +02:00
return nil , status . Errorf ( status . NotFound , "%s not found" , recordType )
2024-09-24 20:57:33 +02:00
}
2024-09-25 11:53:20 +02:00
return nil , status . Errorf ( status . Internal , "failed to get %s from store: %v" , recordType , err )
2024-09-24 20:57:33 +02:00
}
2024-09-25 11:53:20 +02:00
return & record , nil
2024-09-24 20:57:33 +02:00
}