mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-23 08:33:45 +01:00
95299be52d
* Removejsonfile' from test matrix in workflows * Remove sqlite to json migration command * Refactor store engine implementation to remove JSON file store support The codebase has been refactored to remove support for JSON file store storage engine, with SQLite serving as the default store engine. New functions have been added to handle unsupported store engines and to migrate data from file store to SQLite. * Remove 'downCmd' from migration commands * Refactoring * Add sqlite cleanup * Remove comment
715 lines
21 KiB
Go
715 lines
21 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/management/server/account"
|
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/management/server/status"
|
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
const (
|
|
storeSqliteFileName = "store.db"
|
|
)
|
|
|
|
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
|
type SqlStore struct {
|
|
db *gorm.DB
|
|
accountLocks sync.Map
|
|
globalAccountLock sync.Mutex
|
|
metrics telemetry.AppMetrics
|
|
installationPK int
|
|
storeEngine StoreEngine
|
|
}
|
|
|
|
type installation struct {
|
|
ID uint `gorm:"primaryKey"`
|
|
InstallationIDValue string
|
|
}
|
|
|
|
type migrationFunc func(*gorm.DB) error
|
|
|
|
// NewSqlStore creates a new SqlStore instance.
|
|
func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
sql, err := db.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conns := runtime.NumCPU()
|
|
sql.SetMaxOpenConns(conns) // TODO: make it configurable
|
|
|
|
if err := migrate(db); err != nil {
|
|
return nil, fmt.Errorf("migrate: %w", err)
|
|
}
|
|
err = db.AutoMigrate(
|
|
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
|
|
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
|
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auto migrate: %w", err)
|
|
}
|
|
|
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
|
}
|
|
|
|
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
|
func (s *SqlStore) AcquireGlobalLock() (unlock func()) {
|
|
log.Tracef("acquiring global lock")
|
|
start := time.Now()
|
|
s.globalAccountLock.Lock()
|
|
|
|
unlock = func() {
|
|
s.globalAccountLock.Unlock()
|
|
log.Tracef("released global lock in %v", time.Since(start))
|
|
}
|
|
|
|
took := time.Since(start)
|
|
log.Tracef("took %v to acquire global lock", took)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
|
}
|
|
|
|
return unlock
|
|
}
|
|
|
|
func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
|
log.Tracef("acquiring write lock for account %s", accountID)
|
|
|
|
start := time.Now()
|
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
|
mtx := value.(*sync.RWMutex)
|
|
mtx.Lock()
|
|
|
|
unlock = func() {
|
|
mtx.Unlock()
|
|
log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
|
|
}
|
|
|
|
return unlock
|
|
}
|
|
|
|
func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
|
log.Tracef("acquiring read lock for account %s", accountID)
|
|
|
|
start := time.Now()
|
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
|
mtx := value.(*sync.RWMutex)
|
|
mtx.RLock()
|
|
|
|
unlock = func() {
|
|
mtx.RUnlock()
|
|
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
|
|
}
|
|
|
|
return unlock
|
|
}
|
|
|
|
func (s *SqlStore) SaveAccount(account *Account) error {
|
|
start := time.Now()
|
|
|
|
for _, key := range account.SetupKeys {
|
|
account.SetupKeysG = append(account.SetupKeysG, *key)
|
|
}
|
|
|
|
for id, peer := range account.Peers {
|
|
peer.ID = id
|
|
account.PeersG = append(account.PeersG, *peer)
|
|
}
|
|
|
|
for id, user := range account.Users {
|
|
user.Id = id
|
|
for id, pat := range user.PATs {
|
|
pat.ID = id
|
|
user.PATsG = append(user.PATsG, *pat)
|
|
}
|
|
account.UsersG = append(account.UsersG, *user)
|
|
}
|
|
|
|
for id, group := range account.Groups {
|
|
group.ID = id
|
|
account.GroupsG = append(account.GroupsG, *group)
|
|
}
|
|
|
|
for id, route := range account.Routes {
|
|
route.ID = id
|
|
account.RoutesG = append(account.RoutesG, *route)
|
|
}
|
|
|
|
for id, ns := range account.NameServerGroups {
|
|
ns.ID = id
|
|
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
|
}
|
|
|
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.
|
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
|
Create(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
return nil
|
|
})
|
|
|
|
took := time.Since(start)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
|
}
|
|
log.Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *SqlStore) DeleteAccount(account *Account) error {
|
|
start := time.Now()
|
|
|
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
took := time.Since(start)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
|
}
|
|
log.Debugf("took %d ms to delete an account to the store", took.Milliseconds())
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *SqlStore) SaveInstallationID(ID string) error {
|
|
installation := installation{InstallationIDValue: ID}
|
|
installation.ID = uint(s.installationPK)
|
|
|
|
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
|
|
}
|
|
|
|
func (s *SqlStore) GetInstallationID() string {
|
|
var installation installation
|
|
|
|
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
|
|
return ""
|
|
}
|
|
|
|
return installation.InstallationIDValue
|
|
}
|
|
|
|
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
|
var peerCopy nbpeer.Peer
|
|
peerCopy.Status = &peerStatus
|
|
result := s.db.Model(&nbpeer.Peer{}).
|
|
Where("account_id = ? AND id = ?", accountID, peerID).
|
|
Updates(peerCopy)
|
|
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
|
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
|
var peerCopy nbpeer.Peer
|
|
// Since the location field has been migrated to JSON serialization,
|
|
// updating the struct ensures the correct data format is inserted into the database.
|
|
peerCopy.Location = peerWithLocation.Location
|
|
|
|
result := s.db.Model(&nbpeer.Peer{}).
|
|
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
|
|
Updates(peerCopy)
|
|
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
|
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
|
return nil
|
|
}
|
|
|
|
// DeleteTokenID2UserIDIndex is noop in SqlStore
|
|
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
|
var account Account
|
|
|
|
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
|
strings.ToLower(domain), true, PrivateCategory)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
|
}
|
|
log.Errorf("error when getting account from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
// TODO: rework to not call GetAccount
|
|
return s.GetAccount(account.Id)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
|
var key SetupKey
|
|
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting setup key from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
|
|
}
|
|
|
|
if key.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(key.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
|
var token PersonalAccessToken
|
|
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting token from the store: %s", result.Error)
|
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
return token.ID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) {
|
|
var token PersonalAccessToken
|
|
result := s.db.First(&token, "id = ?", tokenID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting token from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
if token.UserID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
var user User
|
|
result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID)
|
|
if result.Error != nil {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
|
|
for _, pat := range user.PATsG {
|
|
user.PATs[pat.ID] = pat.Copy()
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAllAccounts() (all []*Account) {
|
|
var accounts []Account
|
|
result := s.db.Find(&accounts)
|
|
if result.Error != nil {
|
|
return all
|
|
}
|
|
|
|
for _, account := range accounts {
|
|
if acc, err := s.GetAccount(account.Id); err == nil {
|
|
all = append(all, acc)
|
|
}
|
|
}
|
|
|
|
return all
|
|
}
|
|
|
|
func (s *SqlStore) GetAccount(accountID string) (*Account, error) {
|
|
|
|
var account Account
|
|
result := s.db.Model(&account).
|
|
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
|
Preload(clause.Associations).
|
|
First(&account, "id = ?", accountID)
|
|
if result.Error != nil {
|
|
log.Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
|
for i, policy := range account.Policies {
|
|
var rules []*PolicyRule
|
|
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
|
if err != nil {
|
|
return nil, status.Errorf(status.NotFound, "rule not found")
|
|
}
|
|
account.Policies[i].Rules = rules
|
|
}
|
|
|
|
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
|
|
for _, key := range account.SetupKeysG {
|
|
account.SetupKeys[key.Key] = key.Copy()
|
|
}
|
|
account.SetupKeysG = nil
|
|
|
|
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
|
for _, peer := range account.PeersG {
|
|
account.Peers[peer.ID] = peer.Copy()
|
|
}
|
|
account.PeersG = nil
|
|
|
|
account.Users = make(map[string]*User, len(account.UsersG))
|
|
for _, user := range account.UsersG {
|
|
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
|
|
for _, pat := range user.PATsG {
|
|
user.PATs[pat.ID] = pat.Copy()
|
|
}
|
|
account.Users[user.Id] = user.Copy()
|
|
}
|
|
account.UsersG = nil
|
|
|
|
account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG))
|
|
for _, group := range account.GroupsG {
|
|
account.Groups[group.ID] = group.Copy()
|
|
}
|
|
account.GroupsG = nil
|
|
|
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
|
for _, route := range account.RoutesG {
|
|
account.Routes[route.ID] = route.Copy()
|
|
}
|
|
account.RoutesG = nil
|
|
|
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
|
for _, ns := range account.NameServerGroupsG {
|
|
account.NameServerGroups[ns.ID] = ns.Copy()
|
|
}
|
|
account.NameServerGroupsG = nil
|
|
|
|
return &account, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) {
|
|
var user User
|
|
result := s.db.Select("account_id").First(&user, "id = ?", userID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
if user.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(user.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
|
var peer nbpeer.Peer
|
|
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
if peer.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(peer.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
|
var peer nbpeer.Peer
|
|
|
|
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
if peer.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(peer.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
|
var peer nbpeer.Peer
|
|
var accountID string
|
|
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
|
var user User
|
|
var accountID string
|
|
result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
|
var key SetupKey
|
|
var accountID string
|
|
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.Errorf("error when getting setup key from the store: %s", result.Error)
|
|
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
|
var peer nbpeer.Peer
|
|
result := s.db.First(&peer, "key = ?", peerKey)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "peer not found")
|
|
}
|
|
log.Errorf("error when getting peer from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
|
}
|
|
|
|
return &peer, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) {
|
|
var accountSettings AccountSettings
|
|
if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "settings not found")
|
|
}
|
|
log.Errorf("error when getting settings from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
|
}
|
|
return accountSettings.Settings, nil
|
|
}
|
|
|
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
|
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
|
var user User
|
|
|
|
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return status.Errorf(status.NotFound, "user %s not found", userID)
|
|
}
|
|
return status.Errorf(status.Internal, "issue getting user from store")
|
|
}
|
|
|
|
user.LastLogin = lastLogin
|
|
|
|
return s.db.Save(user).Error
|
|
}
|
|
|
|
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
|
definitionJSON, err := json.Marshal(checks)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var postureCheck posture.Checks
|
|
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postureCheck, nil
|
|
}
|
|
|
|
// Close closes the underlying DB connection
|
|
func (s *SqlStore) Close() error {
|
|
sql, err := s.db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("get db: %w", err)
|
|
}
|
|
return sql.Close()
|
|
}
|
|
|
|
// GetStoreEngine returns underlying store engine
|
|
func (s *SqlStore) GetStoreEngine() StoreEngine {
|
|
return s.storeEngine
|
|
}
|
|
|
|
// NewSqliteStore creates a new SQLite store.
|
|
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
|
if runtime.GOOS == "windows" {
|
|
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
|
storeStr = storeSqliteFileName
|
|
}
|
|
|
|
file := filepath.Join(dataDir, storeStr)
|
|
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
CreateBatchSize: 400,
|
|
PrepareStmt: true,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewSqlStore(db, SqliteStoreEngine, metrics)
|
|
}
|
|
|
|
// NewPostgresqlStore creates a new Postgres store.
|
|
func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
PrepareStmt: true,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewSqlStore(db, PostgresStoreEngine, metrics)
|
|
}
|
|
|
|
// newPostgresStore initializes a new Postgres store.
|
|
func newPostgresStore(metrics telemetry.AppMetrics) (Store, error) {
|
|
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
|
}
|
|
return NewPostgresqlStore(dsn, metrics)
|
|
}
|
|
|
|
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
|
|
func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
store, err := NewSqliteStore(dataDir, metrics)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = store.SaveInstallationID(fileStore.InstallationID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, account := range fileStore.GetAllAccounts() {
|
|
err := store.SaveAccount(account)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
|
|
func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
store, err := NewPostgresqlStore(dsn, metrics)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = store.SaveInstallationID(fileStore.InstallationID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, account := range fileStore.GetAllAccounts() {
|
|
err := store.SaveAccount(account)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}
|