diff --git a/management/server/account.go b/management/server/account.go index 5cb7cfd84..6895c9378 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,11 +20,6 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -41,6 +36,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -1255,30 +1254,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and -// userID doesn't have an account associated with it, one account is created -// domain is used to create a new account if no account is found +// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. +// If an accountID is provided, it checks if the account exists and returns it. +// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// If the user doesn't have an account, it creates one using the provided domain. +// Returns the account ID or an error if none is found or created. func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - _, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID) + exists, err := am.Store.AccountExists(ctx, accountID) if err != nil { return "", err } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) + } return accountID, nil - } else if userID != "" { + } + + if userID != "" { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { - return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID) + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - err = am.addAccountIDToIDPAppMeta(ctx, userID, account) - if err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { return "", err } + return account.Id, nil } - return "", status.Errorf(status.NotFound, "no valid user or account Id provided") + return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") } func isNil(i idp.Manager) bool { @@ -1808,6 +1814,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } + // TODO: Remove GetAccount after refactoring account peer's update unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -1907,7 +1914,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId) + domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) if err != nil { return "", err } @@ -1923,7 +1930,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) diff --git a/management/server/file_store.go b/management/server/file_store.go index 84b5547a9..be4c6ec16 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,14 +10,13 @@ import ( "sync" "time" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/util" ) @@ -958,11 +957,11 @@ func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } -func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) { +func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") } -func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) { +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -973,3 +972,13 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID str return account.Domain, account.DomainCategory, nil } + +// AccountExists checks whether an account exists by the given ID. +func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) { + _, exists := s.Accounts[id] + return exists, nil +} + +func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error { + return nil +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b4bcbfbd0..58b258404 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -400,7 +400,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { - accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain) + accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { return nil, err } @@ -409,11 +409,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) return s.GetAccount(ctx, accountID) } -func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) { - var account Account - - result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory) +func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + var accountID string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory, + ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") @@ -422,7 +423,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin return "", status.Errorf(status.Internal, "issue getting account from store") } - return account.Id, nil + return accountID, nil } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { @@ -671,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { - var user User var accountID string - result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1035,10 +1035,53 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// UpdateAccount updates an existing account's domain, DNS settings, and settings fields. +func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error { + updates := make(map[string]interface{}) + + if account.Domain != "" { + updates["domain"] = account.Domain + } + + if account.DNSSettings.DisabledManagementGroups != nil { + updates["dns_settings"] = account.DNSSettings + } + + if account.Settings != nil { + updates["settings"] = account.Settings + } + + if len(updates) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where("id = ?", account.Id).Updates(updates) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to update account: %v", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account not found") + } + + return nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) { + var count int64 + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count) + if result.Error != nil { + return false, result.Error + } + return count > 0, nil +} + // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. -func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) { +func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { var account Account - result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category"). + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { diff --git a/management/server/store.go b/management/server/store.go index 54a559605..8f00f62d6 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -39,8 +39,8 @@ const ( type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) - GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) - DeleteAccount(ctx context.Context, account *Account) error + AccountExists(ctx context.Context, id string) (bool, error) + GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) @@ -49,45 +49,56 @@ type Store interface { GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) - GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) - GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) + GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error + SaveAccount(ctx context.Context, account *Account) error + DeleteAccount(ctx context.Context, account *Account) error + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) - GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - SaveAccount(ctx context.Context, account *Account) error SaveUsers(accountID string, users map[string]*User) error - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + + GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + + GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(accountID string, peer *nbpeer.Peer) error + + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + // Close should close the store persisting all unsaved data. Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error - IncrementNetworkSerial(ctx context.Context, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error }