mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 01:23:22 +01:00
[management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
This commit is contained in:
parent
730dd1733e
commit
6c50b0c84b
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
|
@ -263,6 +263,11 @@ type AccountSettings struct {
|
|||||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subclass used in gorm to only load network and not whole account
|
||||||
|
type AccountNetwork struct {
|
||||||
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermissions struct {
|
||||||
DashboardView string `json:"dashboard_view"`
|
DashboardView string `json:"dashboard_view"`
|
||||||
}
|
}
|
||||||
@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
|||||||
return grps
|
return grps
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getUserGroups(userID string) ([]string, error) {
|
|
||||||
user, err := a.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return user.AutoGroups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||||
peerGroups := a.getPeerGroups(peerID)
|
peerGroups := a.getPeerGroups(peerID)
|
||||||
enabled := true
|
enabled := true
|
||||||
@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
|
|||||||
return groupList
|
return groupList
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
|
|
||||||
key, err := a.FindSetupKey(setupKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key.AutoGroups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getTakenIPs() []net.IP {
|
func (a *Account) getTakenIPs() []net.IP {
|
||||||
var takenIps []net.IP
|
var takenIps []net.IP
|
||||||
for _, existingPeer := range a.Peers {
|
for _, existingPeer := range a.Peers {
|
||||||
@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
|
||||||
|
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
labelMap := ConvertSliceToMap(existingLabels)
|
||||||
|
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newLabel == "" {
|
||||||
|
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newLabel, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exist
|
// addAllGroup to account object if it doesn't exist
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockStore struct {
|
type MockStore struct {
|
||||||
@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
|
|||||||
return s.account, nil
|
return s.account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("account not found")
|
return nil, status.NewPeerNotFoundError(peerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MocAccountManager struct {
|
type MocAccountManager struct {
|
||||||
|
@ -2,6 +2,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -46,6 +48,158 @@ type FileStore struct {
|
|||||||
metrics telemetry.AppMetrics `json:"-"`
|
metrics telemetry.AppMetrics `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
|
||||||
|
return f(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
|
||||||
|
if !ok {
|
||||||
|
return status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.SetupKeys[setupKeyID].UsedTimes++
|
||||||
|
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil || allGroup == nil {
|
||||||
|
return errors.New("all group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroup.Peers = append(allGroup.Peers, peerID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, ok := s.Accounts[peer.AccountID]
|
||||||
|
if !ok {
|
||||||
|
return status.NewAccountNotFoundError(peer.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Peers[peer.ID] = peer
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, ok := s.Accounts[accountId]
|
||||||
|
if !ok {
|
||||||
|
return status.NewAccountNotFoundError(accountId)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Network.Serial++
|
||||||
|
|
||||||
|
return s.SaveAccount(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, ok := account.SetupKeys[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var takenIps []net.IP
|
||||||
|
for _, existingPeer := range account.Peers {
|
||||||
|
takenIps = append(takenIps, existingPeer.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return takenIps, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
existingLabels := []string{}
|
||||||
|
for _, peer := range account.Peers {
|
||||||
|
if peer.DNSLabel != "" {
|
||||||
|
existingLabels = append(existingLabels, peer.DNSLabel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return existingLabels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Network, nil
|
||||||
|
}
|
||||||
|
|
||||||
type StoredAccount struct{}
|
type StoredAccount struct{}
|
||||||
|
|
||||||
// NewFileStore restores a store from the file located in the datadir
|
// NewFileStore restores a store from the file located in the datadir
|
||||||
@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
|
|||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
|||||||
return account.Users[userID].Copy(), nil
|
return account.Users[userID].Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
accountID, ok := s.UserID2AccountID[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||||
@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
|||||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||||
account, ok := s.Accounts[accountID]
|
account, ok := s.Accounts[accountID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found")
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
|
|||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
return "", status.NewSetupKeyNotFoundError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return accountID, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
|
|||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
return nil, status.NewPeerNotFoundError(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
||||||
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginPerformance(t *testing.T) {
|
func Test_LoginPerformance(t *testing.T) {
|
||||||
if os.Getenv("CI") == "true" {
|
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||||
t.Skip("Skipping on CI")
|
t.Skip("Skipping test on CI or Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||||
@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
// {"M", 250, 1},
|
// {"M", 250, 1},
|
||||||
// {"L", 500, 1},
|
// {"L", 500, 1},
|
||||||
// {"XL", 750, 1},
|
// {"XL", 750, 1},
|
||||||
{"XXL", 2000, 1},
|
{"XXL", 5000, 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer mgmtServer.GracefulStop()
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
t.Logf("management setup complete, start registering peers")
|
||||||
|
|
||||||
var counter int32
|
var counter int32
|
||||||
var counterStart int32
|
var counterStart int32
|
||||||
var wg sync.WaitGroup
|
var wgAccount sync.WaitGroup
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
messageCalls := []func() error{}
|
messageCalls := []func() error{}
|
||||||
for j := 0; j < bc.accounts; j++ {
|
for j := 0; j < bc.accounts; j++ {
|
||||||
wg.Add(1)
|
wgAccount.Add(1)
|
||||||
|
var wgPeer sync.WaitGroup
|
||||||
go func(j int, counter *int32, counterStart *int32) {
|
go func(j int, counter *int32, counterStart *int32) {
|
||||||
defer wg.Done()
|
defer wgAccount.Done()
|
||||||
|
|
||||||
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
for i := 0; i < bc.peers; i++ {
|
for i := 0; i < bc.peers; i++ {
|
||||||
|
wgPeer.Add(1)
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("failed to generate key: %v", err)
|
t.Logf("failed to generate key: %v", err)
|
||||||
@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
|
|||||||
mu.Lock()
|
mu.Lock()
|
||||||
messageCalls = append(messageCalls, login)
|
messageCalls = append(messageCalls, login)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("failed to login peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.AddInt32(counterStart, 1)
|
go func(peerLogin PeerLogin, counterStart *int32) {
|
||||||
if *counterStart%100 == 0 {
|
defer wgPeer.Done()
|
||||||
t.Logf("registered %d peers", *counterStart)
|
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Logf("failed to login peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt32(counterStart, 1)
|
||||||
|
if *counterStart%100 == 0 {
|
||||||
|
t.Logf("registered %d peers", *counterStart)
|
||||||
|
}
|
||||||
|
}(peerLogin, counterStart)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
wgPeer.Wait()
|
||||||
|
|
||||||
|
t.Logf("Time for registration: %s", time.Since(startTime))
|
||||||
}(j, &counter, &counterStart)
|
}(j, &counter, &counterStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wgAccount.Wait()
|
||||||
|
|
||||||
t.Logf("prepared %d login calls", len(messageCalls))
|
t.Logf("prepared %d login calls", len(messageCalls))
|
||||||
testLoginPerformance(t, messageCalls)
|
testLoginPerformance(t, messageCalls)
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var account *Account
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
|
||||||
account, err = am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
|
||||||
if am.idpManager != nil {
|
|
||||||
userdata, err := am.lookupUserInCache(ctx, userID, account)
|
|
||||||
if err == nil && userdata != nil {
|
|
||||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||||
// and the peer disconnects with a timeout and tries to register again.
|
// and the peer disconnects with a timeout and tries to register again.
|
||||||
// We just check if this machine has been registered before and reject the second registration.
|
// We just check if this machine has been registered before and reject the second registration.
|
||||||
// The connecting peer should be able to recover with a retry.
|
// The connecting peer should be able to recover with a retry.
|
||||||
_, err = account.FindPeerByPubKey(peer.Key)
|
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent := &activity.Event{
|
opEvent := &activity.Event{
|
||||||
Timestamp: time.Now().UTC(),
|
Timestamp: time.Now().UTC(),
|
||||||
AccountID: account.Id,
|
AccountID: accountID,
|
||||||
}
|
}
|
||||||
|
|
||||||
var ephemeral bool
|
var newPeer *nbpeer.Peer
|
||||||
setupKeyName := ""
|
|
||||||
if !addedByUser {
|
|
||||||
// validate the setup key if adding with a key
|
|
||||||
sk, err := account.FindSetupKey(upperKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sk.IsValid() {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
var groupsToAdd []string
|
||||||
}
|
var setupKeyID string
|
||||||
|
var setupKeyName string
|
||||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
var ephemeral bool
|
||||||
opEvent.InitiatorID = sk.Id
|
if addedByUser {
|
||||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||||
ephemeral = sk.Ephemeral
|
if err != nil {
|
||||||
setupKeyName = sk.Name
|
return fmt.Errorf("failed to get user groups: %w", err)
|
||||||
} else {
|
}
|
||||||
opEvent.InitiatorID = userID
|
groupsToAdd = user.AutoGroups
|
||||||
opEvent.Activity = activity.PeerAddedByUser
|
opEvent.InitiatorID = userID
|
||||||
}
|
opEvent.Activity = activity.PeerAddedByUser
|
||||||
|
|
||||||
takenIps := account.getTakenIPs()
|
|
||||||
existingLabels := account.getPeerDNSLabels()
|
|
||||||
|
|
||||||
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.DNSLabel = newLabel
|
|
||||||
network := account.Network
|
|
||||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
|
||||||
|
|
||||||
newPeer := &nbpeer.Peer{
|
|
||||||
ID: xid.New().String(),
|
|
||||||
Key: peer.Key,
|
|
||||||
SetupKey: upperKey,
|
|
||||||
IP: nextIp,
|
|
||||||
Meta: peer.Meta,
|
|
||||||
Name: peer.Meta.Hostname,
|
|
||||||
DNSLabel: newLabel,
|
|
||||||
UserID: userID,
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
|
||||||
SSHEnabled: false,
|
|
||||||
SSHKey: peer.SSHKey,
|
|
||||||
LastLogin: registrationTime,
|
|
||||||
CreatedAt: registrationTime,
|
|
||||||
LoginExpirationEnabled: addedByUser,
|
|
||||||
Ephemeral: ephemeral,
|
|
||||||
Location: peer.Location,
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
|
||||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
|
||||||
} else {
|
} else {
|
||||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
// Validate the setup key
|
||||||
newPeer.Location.CityName = location.City.Names.En
|
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
|
||||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("failed to get setup key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add peer to 'All' group
|
if !sk.IsValid() {
|
||||||
group, err := account.GetGroupAll()
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
if err != nil {
|
}
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
group.Peers = append(group.Peers, newPeer.ID)
|
|
||||||
|
|
||||||
var groupsToAdd []string
|
opEvent.InitiatorID = sk.Id
|
||||||
if addedByUser {
|
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||||
groupsToAdd, err = account.getUserGroups(userID)
|
groupsToAdd = sk.AutoGroups
|
||||||
if err != nil {
|
ephemeral = sk.Ephemeral
|
||||||
return nil, nil, nil, err
|
setupKeyID = sk.Id
|
||||||
|
setupKeyName = sk.Name
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(groupsToAdd) > 0 {
|
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||||
for _, s := range groupsToAdd {
|
if am.idpManager != nil {
|
||||||
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||||
g.Peers = append(g.Peers, newPeer.ID)
|
if err == nil && userdata != nil {
|
||||||
|
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||||
|
|
||||||
if addedByUser {
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
}
|
}
|
||||||
user.updateLastLogin(newPeer.LastLogin)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newPeer.ID] = newPeer
|
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
||||||
account.Network.IncSerial()
|
if err != nil {
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
return fmt.Errorf("failed to get free IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registrationTime := time.Now().UTC()
|
||||||
|
newPeer = &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
|
Key: peer.Key,
|
||||||
|
SetupKey: upperKey,
|
||||||
|
IP: freeIP,
|
||||||
|
Meta: peer.Meta,
|
||||||
|
Name: peer.Meta.Hostname,
|
||||||
|
DNSLabel: freeLabel,
|
||||||
|
UserID: userID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
|
SSHEnabled: false,
|
||||||
|
SSHKey: peer.SSHKey,
|
||||||
|
LastLogin: registrationTime,
|
||||||
|
CreatedAt: registrationTime,
|
||||||
|
LoginExpirationEnabled: addedByUser,
|
||||||
|
Ephemeral: ephemeral,
|
||||||
|
Location: peer.Location,
|
||||||
|
}
|
||||||
|
opEvent.TargetID = newPeer.ID
|
||||||
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||||
|
if !addedByUser {
|
||||||
|
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||||
|
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||||
|
} else {
|
||||||
|
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||||
|
newPeer.Location.CityName = location.City.Names.En
|
||||||
|
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get account settings: %w", err)
|
||||||
|
}
|
||||||
|
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||||
|
|
||||||
|
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(groupsToAdd) > 0 {
|
||||||
|
for _, g := range groupsToAdd {
|
||||||
|
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if addedByUser {
|
||||||
|
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update user last login: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account is saved, we can release the lock
|
if newPeer == nil {
|
||||||
unlock()
|
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||||
unlock = nil
|
|
||||||
|
|
||||||
opEvent.TargetID = newPeer.ID
|
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
|
||||||
if !addedByUser {
|
|
||||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
|
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
return newPeer, networkMap, postureChecks, nil
|
return newPeer, networkMap, postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||||
|
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextIp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||||
@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||||
|
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||||
|
for _, label := range existingLabels {
|
||||||
|
labelMap[label] = struct{}{}
|
||||||
|
}
|
||||||
|
return labelMap
|
||||||
|
}
|
||||||
|
@ -7,20 +7,24 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(response.Checks))
|
assert.Equal(t, 1, len(response.Checks))
|
||||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerByUser(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
UserID: existingUserID,
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
LastLogin: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||||
|
assert.Equal(t, peer.UserID, existingUserID)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||||
|
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||||
|
|
||||||
|
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "existingSetupKey",
|
||||||
|
UserID: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||||
|
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||||
|
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||||
|
|
||||||
|
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
|
||||||
|
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
|
||||||
|
|
||||||
|
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
Key: "newPeerKey",
|
||||||
|
SetupKey: "existingSetupKey",
|
||||||
|
UserID: "",
|
||||||
|
IP: net.IP{123, 123, 123, 123},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
Hostname: "newPeer",
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
Name: "newPeerName",
|
||||||
|
DNSLabel: "newPeer.test",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
|
SSHEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotContains(t, account.Peers, newPeer.ID)
|
||||||
|
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
|
||||||
|
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(0), account.Network.Serial)
|
||||||
|
|
||||||
|
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
|
||||||
|
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -33,6 +34,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
storeSqliteFileName = "store.db"
|
storeSqliteFileName = "store.db"
|
||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||||
var key SetupKey
|
var key SetupKey
|
||||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if key.AccountID == "" {
|
if key.AccountID == "" {
|
||||||
@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.First(&user, idQueryCondition, userID)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
return nil, status.NewGetUserFromStoreError()
|
||||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
|||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found")
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
|||||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
|
|
||||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
|||||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||||
var key SetupKey
|
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
return "", status.NewSetupKeyNotFoundError()
|
||||||
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return accountID, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
|
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||||
|
var ipJSONStrings []string
|
||||||
|
|
||||||
|
// Fetch the IP addresses as JSON strings
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ?", accountID).
|
||||||
|
Pluck("ip", &ipJSONStrings)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the JSON strings to net.IP objects
|
||||||
|
ips := make([]net.IP, len(ipJSONStrings))
|
||||||
|
for i, ipJSON := range ipJSONStrings {
|
||||||
|
var ip net.IP
|
||||||
|
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
||||||
|
}
|
||||||
|
ips[i] = ip
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||||
|
var labels []string
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||||
|
Where("account_id = ?", accountID).
|
||||||
|
Pluck("dns_label", &labels)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return labels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||||
|
var accountNetwork AccountNetwork
|
||||||
|
|
||||||
|
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting network from store")
|
||||||
|
}
|
||||||
|
return accountNetwork.Network, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := s.db.First(&peer, "key = ?", peerKey)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &peer, nil
|
return &peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
|
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||||
var accountSettings AccountSettings
|
var accountSettings AccountSettings
|
||||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
||||||
}
|
}
|
||||||
return accountSettings.Settings, nil
|
return accountSettings.Settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
var user User
|
var user User
|
||||||
|
|
||||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
return status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
return status.Errorf(status.Internal, "issue getting user from store")
|
return status.NewGetUserFromStoreError()
|
||||||
}
|
}
|
||||||
|
|
||||||
user.LastLogin = lastLogin
|
user.LastLogin = lastLogin
|
||||||
|
|
||||||
return s.db.Save(user).Error
|
return s.db.Save(&user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
|
|||||||
|
|
||||||
return store, nil
|
return store, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||||
|
var setupKey SetupKey
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
return nil, status.NewSetupKeyNotFoundError()
|
||||||
|
}
|
||||||
|
return &setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||||
|
Where(idQueryCondition, setupKeyID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"used_times": gorm.Expr("used_times + 1"),
|
||||||
|
"last_used": time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||||
|
}
|
||||||
|
return status.Errorf(status.Internal, "issue finding group 'All'")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingPeerID := range group.Peers {
|
||||||
|
if existingPeerID == peerID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = append(group.Peers, peerID)
|
||||||
|
|
||||||
|
if err := s.db.Save(&group).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue updating group 'All'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return status.Errorf(status.NotFound, "group not found for account")
|
||||||
|
}
|
||||||
|
return status.Errorf(status.Internal, "issue finding group")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingPeerID := range group.Peers {
|
||||||
|
if existingPeerID == peerId {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group.Peers = append(group.Peers, peerId)
|
||||||
|
|
||||||
|
if err := s.db.Save(&group).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue updating group")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
|
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||||
|
tx := s.db.WithContext(ctx).Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
return tx.Error
|
||||||
|
}
|
||||||
|
repo := s.withTx(tx)
|
||||||
|
err := operation(repo)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit().Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||||
|
return &SqlStore{
|
||||||
|
db: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
require.Equal(t, id, user.PATs[id].ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []net.IP{}, takenIPs)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip1 := net.IP{1, 1, 1, 1}.To16()
|
||||||
|
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||||
|
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{}, labels)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1.domain.test",
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer2.domain.test",
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ip := net.IP{100, 64, 0, 0}.To16()
|
||||||
|
assert.Equal(t, ip, network.Net.IP)
|
||||||
|
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
|
||||||
|
assert.Equal(t, "", network.Dns)
|
||||||
|
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
|
||||||
|
assert.Equal(t, uint64(0), network.Serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
|
||||||
|
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||||
|
assert.Equal(t, "Default key", setupKey.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||||
|
defer store.Close(context.Background())
|
||||||
|
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, setupKey.UsedTimes)
|
||||||
|
|
||||||
|
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, setupKey.UsedTimes)
|
||||||
|
|
||||||
|
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||||
|
}
|
||||||
|
@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
|
|||||||
func NewPeerLoginExpiredError() error {
|
func NewPeerLoginExpiredError() error {
|
||||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||||
|
func NewSetupKeyNotFoundError() error {
|
||||||
|
return Errorf(NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||||
|
func NewGetUserFromStoreError() error {
|
||||||
|
return Errorf(Internal, "issue getting user from store")
|
||||||
|
}
|
||||||
|
@ -27,6 +27,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LockingStrength string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
|
||||||
|
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
|
||||||
|
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
|
||||||
|
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
|
||||||
|
)
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
GetAllAccounts(ctx context.Context) []*Account
|
GetAllAccounts(ctx context.Context) []*Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
@ -41,7 +50,7 @@ type Store interface {
|
|||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
@ -60,14 +69,24 @@ type Store interface {
|
|||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
// Close should close the store persisting all unsaved data.
|
// Close should close the store persisting all unsaved data.
|
||||||
Close(ctx context.Context) error
|
Close(ctx context.Context) error
|
||||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||||
// This is also a method of metrics.DataSource interface.
|
// This is also a method of metrics.DataSource interface.
|
||||||
GetStoreEngine() StoreEngine
|
GetStoreEngine() StoreEngine
|
||||||
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||||
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||||
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type StoreEngine string
|
type StoreEngine string
|
||||||
|
120
management/server/testdata/extended-store.json
vendored
Normal file
120
management/server/testdata/extended-store.json
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
{
|
||||||
|
"Accounts": {
|
||||||
|
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||||
|
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
"CreatedBy": "",
|
||||||
|
"Domain": "test.com",
|
||||||
|
"DomainCategory": "private",
|
||||||
|
"IsDomainPrimaryAccount": true,
|
||||||
|
"SetupKeys": {
|
||||||
|
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||||
|
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||||
|
"AccountID": "",
|
||||||
|
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||||
|
"Name": "Default key",
|
||||||
|
"Type": "reusable",
|
||||||
|
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||||
|
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||||
|
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||||
|
"Revoked": false,
|
||||||
|
"UsedTimes": 0,
|
||||||
|
"LastUsed": "0001-01-01T00:00:00Z",
|
||||||
|
"AutoGroups": ["cfefqs706sqkneg59g2g"],
|
||||||
|
"UsageLimit": 0,
|
||||||
|
"Ephemeral": false
|
||||||
|
},
|
||||||
|
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
|
||||||
|
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||||
|
"AccountID": "",
|
||||||
|
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||||
|
"Name": "Faulty key with non existing group",
|
||||||
|
"Type": "reusable",
|
||||||
|
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||||
|
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||||
|
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||||
|
"Revoked": false,
|
||||||
|
"UsedTimes": 0,
|
||||||
|
"LastUsed": "0001-01-01T00:00:00Z",
|
||||||
|
"AutoGroups": ["abcd"],
|
||||||
|
"UsageLimit": 0,
|
||||||
|
"Ephemeral": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Network": {
|
||||||
|
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
|
||||||
|
"Net": {
|
||||||
|
"IP": "100.64.0.0",
|
||||||
|
"Mask": "//8AAA=="
|
||||||
|
},
|
||||||
|
"Dns": "",
|
||||||
|
"Serial": 0
|
||||||
|
},
|
||||||
|
"Peers": {},
|
||||||
|
"Users": {
|
||||||
|
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"AccountID": "",
|
||||||
|
"Role": "admin",
|
||||||
|
"IsServiceUser": false,
|
||||||
|
"ServiceUserName": "",
|
||||||
|
"AutoGroups": ["cfefqs706sqkneg59g3g"],
|
||||||
|
"PATs": {},
|
||||||
|
"Blocked": false,
|
||||||
|
"LastLogin": "0001-01-01T00:00:00Z"
|
||||||
|
},
|
||||||
|
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"AccountID": "",
|
||||||
|
"Role": "user",
|
||||||
|
"IsServiceUser": false,
|
||||||
|
"ServiceUserName": "",
|
||||||
|
"AutoGroups": null,
|
||||||
|
"PATs": {
|
||||||
|
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
|
||||||
|
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||||
|
"UserID": "",
|
||||||
|
"Name": "",
|
||||||
|
"HashedToken": "SoMeHaShEdToKeN",
|
||||||
|
"ExpirationDate": "2023-02-27T00:00:00Z",
|
||||||
|
"CreatedBy": "user",
|
||||||
|
"CreatedAt": "2023-01-01T00:00:00Z",
|
||||||
|
"LastUsed": "2023-02-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Blocked": false,
|
||||||
|
"LastLogin": "0001-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Groups": {
|
||||||
|
"cfefqs706sqkneg59g4g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g4g",
|
||||||
|
"Name": "All",
|
||||||
|
"Peers": []
|
||||||
|
},
|
||||||
|
"cfefqs706sqkneg59g3g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g3g",
|
||||||
|
"Name": "AwesomeGroup1",
|
||||||
|
"Peers": []
|
||||||
|
},
|
||||||
|
"cfefqs706sqkneg59g2g": {
|
||||||
|
"ID": "cfefqs706sqkneg59g2g",
|
||||||
|
"Name": "AwesomeGroup2",
|
||||||
|
"Peers": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"Rules": null,
|
||||||
|
"Policies": [],
|
||||||
|
"Routes": null,
|
||||||
|
"NameServerGroups": null,
|
||||||
|
"DNSSettings": null,
|
||||||
|
"Settings": {
|
||||||
|
"PeerLoginExpirationEnabled": false,
|
||||||
|
"PeerLoginExpiration": 86400000000000,
|
||||||
|
"GroupsPropagationEnabled": false,
|
||||||
|
"JWTGroupsEnabled": false,
|
||||||
|
"JWTGroupsClaimName": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"InstallationID": ""
|
||||||
|
}
|
@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
|||||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) updateLastLogin(login time.Time) {
|
|
||||||
u.LastLogin = login
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||||
func (u *User) HasAdminPower() bool {
|
func (u *User) HasAdminPower() bool {
|
||||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||||
@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
|||||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user