Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-24 13:30:13 +03:00
parent d9f612d623
commit 28840383e1
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 125 additions and 55 deletions

View File

@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
) )
const ( const (
@ -1255,30 +1254,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil return nil
} }
// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and // GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// userID doesn't have an account associated with it, one account is created // If an accountID is provided, it checks if the account exists and returns it.
// domain is used to create a new account if no account is found // If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" { if accountID != "" {
_, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID) exists, err := am.Store.AccountExists(ctx, accountID)
if err != nil { if err != nil {
return "", err return "", err
} }
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil return accountID, nil
} else if userID != "" { }
if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil { if err != nil {
return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID) return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
} }
err = am.addAccountIDToIDPAppMeta(ctx, userID, account) if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
if err != nil {
return "", err return "", err
} }
return account.Id, nil return account.Id, nil
} }
return "", status.Errorf(status.NotFound, "no valid user or account Id provided") return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@ -1808,6 +1814,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return nil return nil
} }
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -1907,7 +1914,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
} }
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId) domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1923,7 +1930,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil { if err != nil {
// if NotFound we are good to continue, otherwise return error // if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err) e, ok := status.FromError(err)

View File

@ -10,14 +10,13 @@ import (
"sync" "sync"
"time" "time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -958,11 +957,11 @@ func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented") return status.Errorf(status.Internal, "SaveGroups is not implemented")
} }
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) { func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
} }
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) { func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@ -973,3 +972,13 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID str
return account.Domain, account.DomainCategory, nil return account.Domain, account.DomainCategory, nil
} }
// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}
func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error {
return nil
}

View File

@ -400,7 +400,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain) accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -409,11 +409,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
return s.GetAccount(ctx, accountID) return s.GetAccount(ctx, accountID)
} }
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) { func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var account Account var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory) strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
@ -422,7 +423,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.Errorf(status.Internal, "issue getting account from store")
} }
return account.Id, nil return accountID, nil
} }
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
@ -671,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
} }
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -1035,10 +1035,53 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
} }
} }
// UpdateAccount updates an existing account's domain, DNS settings, and settings fields.
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
updates := make(map[string]interface{})
if account.Domain != "" {
updates["domain"] = account.Domain
}
if account.DNSSettings.DisabledManagementGroups != nil {
updates["dns_settings"] = account.DNSSettings
}
if account.Settings != nil {
updates["settings"] = account.Settings
}
if len(updates) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where("id = ?", account.Id).Updates(updates)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to update account: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
return count > 0, nil
}
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) { func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account var account Account
result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category"). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account) Where(idQueryCondition, accountID).First(&account)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {

View File

@ -39,8 +39,8 @@ const (
type Store interface { type Store interface {
GetAllAccounts(ctx context.Context) []*Account GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) AccountExists(ctx context.Context, id string) (bool, error)
DeleteAccount(ctx context.Context, account *Account) error GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
@ -49,45 +49,56 @@ type Store interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error SaveInstallationID(ctx context.Context, ID string) error
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func() AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func() AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data. // Close should close the store persisting all unsaved data.
Close(ctx context.Context) error Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation. // GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
} }