mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 02:41:34 +01:00
refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
d9f612d623
commit
28840383e1
@ -20,11 +20,6 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@ -41,6 +36,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -1255,30 +1254,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// userID doesn't have an account associated with it, one account is created
|
||||
// domain is used to create a new account if no account is found
|
||||
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
||||
// If an accountID is provided, it checks if the account exists and returns it.
|
||||
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||
if accountID != "" {
|
||||
_, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID)
|
||||
exists, err := am.Store.AccountExists(ctx, accountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
||||
}
|
||||
return accountID, nil
|
||||
} else if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
||||
if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
}
|
||||
|
||||
return "", status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
||||
}
|
||||
|
||||
func isNil(i idp.Manager) bool {
|
||||
@ -1808,6 +1814,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove GetAccount after refactoring account peer's update
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -1907,7 +1914,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId)
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -1923,7 +1930,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
|
@ -10,14 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@ -958,11 +957,11 @@ func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
|
||||
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -973,3 +972,13 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID str
|
||||
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
|
||||
_, exists := s.Accounts[id]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -400,7 +400,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain)
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -409,11 +409,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
return s.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory)
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory,
|
||||
).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
@ -422,7 +423,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
@ -671,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
var user User
|
||||
var accountID string
|
||||
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@ -1035,10 +1035,53 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAccount updates an existing account's domain, DNS settings, and settings fields.
|
||||
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
if account.Domain != "" {
|
||||
updates["domain"] = account.Domain
|
||||
}
|
||||
|
||||
if account.DNSSettings.DisabledManagementGroups != nil {
|
||||
updates["dns_settings"] = account.DNSSettings
|
||||
}
|
||||
|
||||
if account.Settings != nil {
|
||||
updates["settings"] = account.Settings
|
||||
}
|
||||
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where("id = ?", account.Id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to update account: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) {
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category").
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
|
@ -39,8 +39,8 @@ const (
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
AccountExists(ctx context.Context, id string) (bool, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
@ -49,45 +49,56 @@ type Store interface {
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
|
||||
// Close should close the store persisting all unsaved data.
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user