mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 17:13:30 +01:00
[management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
This commit is contained in:
parent
730dd1733e
commit
6c50b0c84b
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
|
@ -263,6 +263,11 @@ type AccountSettings struct {
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
}
|
||||
|
||||
// Subclass used in gorm to only load network and not whole account
|
||||
type AccountNetwork struct {
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getUserGroups(userID string) ([]string, error) {
|
||||
user, err := a.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.AutoGroups, nil
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.getPeerGroups(peerID)
|
||||
enabled := true
|
||||
@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
|
||||
return groupList
|
||||
}
|
||||
|
||||
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
|
||||
key, err := a.FindSetupKey(setupKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key.AutoGroups, nil
|
||||
}
|
||||
|
||||
func (a *Account) getTakenIPs() []net.IP {
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range a.Peers {
|
||||
@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
|
||||
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||
}
|
||||
|
||||
labelMap := ConvertSliceToMap(existingLabels)
|
||||
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
if newLabel == "" {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
return newLabel, nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
|
||||
return s.account, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("account not found")
|
||||
return nil, status.NewPeerNotFoundError(peerId)
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
|
@ -2,6 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -46,6 +48,158 @@ type FileStore struct {
|
||||
metrics telemetry.AppMetrics `json:"-"`
|
||||
}
|
||||
|
||||
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
|
||||
return f(s)
|
||||
}
|
||||
|
||||
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
|
||||
if !ok {
|
||||
return status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.SetupKeys[setupKeyID].UsedTimes++
|
||||
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil || allGroup == nil {
|
||||
return errors.New("all group not found")
|
||||
}
|
||||
|
||||
allGroup.Peers = append(allGroup.Peers, peerID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, ok := s.Accounts[peer.AccountID]
|
||||
if !ok {
|
||||
return status.NewAccountNotFoundError(peer.AccountID)
|
||||
}
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, ok := s.Accounts[accountId]
|
||||
if !ok {
|
||||
return status.NewAccountNotFoundError(accountId)
|
||||
}
|
||||
|
||||
account.Network.Serial++
|
||||
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
|
||||
if !ok {
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey, ok := account.SetupKeys[key]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range account.Peers {
|
||||
takenIps = append(takenIps, existingPeer.IP)
|
||||
}
|
||||
|
||||
return takenIps, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingLabels := []string{}
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel != "" {
|
||||
existingLabels = append(existingLabels, peer.DNSLabel)
|
||||
}
|
||||
}
|
||||
return existingLabels, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Network, nil
|
||||
}
|
||||
|
||||
type StoredAccount struct{}
|
||||
|
||||
// NewFileStore restores a store from the file located in the datadir
|
||||
@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||
@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
account, ok := s.Accounts[accountID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
return "", status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
||||
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
|
@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPerformance(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping on CI")
|
||||
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on CI or Windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||
@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
// {"M", 250, 1},
|
||||
// {"L", 500, 1},
|
||||
// {"XL", 750, 1},
|
||||
{"XXL", 2000, 1},
|
||||
{"XXL", 5000, 1},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
t.Logf("management setup complete, start registering peers")
|
||||
|
||||
var counter int32
|
||||
var counterStart int32
|
||||
var wg sync.WaitGroup
|
||||
var wgAccount sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
messageCalls := []func() error{}
|
||||
for j := 0; j < bc.accounts; j++ {
|
||||
wg.Add(1)
|
||||
wgAccount.Add(1)
|
||||
var wgPeer sync.WaitGroup
|
||||
go func(j int, counter *int32, counterStart *int32) {
|
||||
defer wg.Done()
|
||||
defer wgAccount.Done()
|
||||
|
||||
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||
if err != nil {
|
||||
@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
for i := 0; i < bc.peers; i++ {
|
||||
wgPeer.Add(1)
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Logf("failed to generate key: %v", err)
|
||||
@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
mu.Lock()
|
||||
messageCalls = append(messageCalls, login)
|
||||
mu.Unlock()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt32(counterStart, 1)
|
||||
if *counterStart%100 == 0 {
|
||||
t.Logf("registered %d peers", *counterStart)
|
||||
}
|
||||
go func(peerLogin PeerLogin, counterStart *int32) {
|
||||
defer wgPeer.Done()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt32(counterStart, 1)
|
||||
if *counterStart%100 == 0 {
|
||||
t.Logf("registered %d peers", *counterStart)
|
||||
}
|
||||
}(peerLogin, counterStart)
|
||||
|
||||
}
|
||||
wgPeer.Wait()
|
||||
|
||||
t.Logf("Time for registration: %s", time.Since(startTime))
|
||||
}(j, &counter, &counterStart)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
wgAccount.Wait()
|
||||
|
||||
t.Logf("prepared %d login calls", len(messageCalls))
|
||||
testLoginPerformance(t, messageCalls)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}()
|
||||
|
||||
var account *Account
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err = account.FindPeerByPubKey(peer.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
AccountID: account.Id,
|
||||
AccountID: accountID,
|
||||
}
|
||||
|
||||
var ephemeral bool
|
||||
setupKeyName := ""
|
||||
if !addedByUser {
|
||||
// validate the setup key if adding with a key
|
||||
sk, err := account.FindSetupKey(upperKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
var newPeer *nbpeer.Peer
|
||||
|
||||
if !sk.IsValid() {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyName = sk.Name
|
||||
} else {
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
}
|
||||
|
||||
takenIps := account.getTakenIPs()
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
|
||||
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
network := account.Network
|
||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: nextIp,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: newLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var groupsToAdd []string
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
groupsToAdd = user.AutoGroups
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
// Validate the setup key
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
group, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
group.Peers = append(group.Peers, newPeer.ID)
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
var groupsToAdd []string
|
||||
if addedByUser {
|
||||
groupsToAdd, err = account.getUserGroups(userID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
}
|
||||
} else {
|
||||
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, s := range groupsToAdd {
|
||||
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
||||
g.Peers = append(g.Peers, newPeer.ID)
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
||||
|
||||
if addedByUser {
|
||||
user, err := account.FindUser(userID)
|
||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
user.updateLastLogin(newPeer.LastLogin)
|
||||
}
|
||||
|
||||
account.Peers[newPeer.ID] = newPeer
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
}
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user last login: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
}
|
||||
|
||||
// Account is saved, we can release the lock
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
}
|
||||
|
||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||
}
|
||||
|
||||
return nextIp, nil
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}()
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
labelMap[label] = struct{}{}
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
@ -7,20 +7,24 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
|
||||
assert.Equal(t, 1, len(response.Checks))
|
||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||
}
|
||||
|
||||
func Test_RegisterPeerByUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
UserID: existingUserID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
LastLogin: time.Now(),
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.UserID, existingUserID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||
|
||||
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
|
||||
}
|
||||
|
||||
func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
|
||||
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
|
||||
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
require.Error(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, account.Peers, newPeer.ID)
|
||||
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
|
||||
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(0), account.Network.Serial)
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
|
||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@ -33,6 +34,7 @@ import (
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
var key SetupKey
|
||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
if key.AccountID == "" {
|
||||
@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.First(&user, idQueryCondition, userID)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
var user User
|
||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
var key SetupKey
|
||||
var accountID string
|
||||
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
||||
return "", status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("ip", &ipJSONStrings)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
|
||||
}
|
||||
|
||||
// Convert the JSON strings to net.IP objects
|
||||
ips := make([]net.IP, len(ipJSONStrings))
|
||||
for i, ipJSON := range ipJSONStrings {
|
||||
var ip net.IP
|
||||
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
||||
}
|
||||
ips[i] = ip
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
var labels []string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("dns_label", &labels)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
|
||||
}
|
||||
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
var accountNetwork AccountNetwork
|
||||
|
||||
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store")
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.First(&peer, "key = ?", peerKey)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
||||
}
|
||||
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
|
||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue getting user from store")
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
user.LastLogin = lastLogin
|
||||
|
||||
return s.db.Save(user).Error
|
||||
return s.db.Save(&user).Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
var setupKey SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
"last_used": time.Now(),
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All'")
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group")
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerId {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
repo := s.withTx(tx)
|
||||
err := operation(repo)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
return &SqlStore{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
|
@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []net.IP{}, takenIPs)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip1 := net.IP{1, 1, 1, 1}.To16()
|
||||
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer2.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip := net.IP{100, 64, 0, 0}.To16()
|
||||
assert.Equal(t, ip, network.Net.IP)
|
||||
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
|
||||
assert.Equal(t, "", network.Dns)
|
||||
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
|
||||
assert.Equal(t, uint64(0), network.Serial)
|
||||
}
|
||||
|
||||
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
|
||||
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||
assert.Equal(t, "Default key", setupKey.Name)
|
||||
}
|
||||
|
||||
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||
}
|
||||
|
@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
|
||||
func NewPeerLoginExpiredError() error {
|
||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError() error {
|
||||
return Errorf(NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
@ -27,6 +27,15 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type LockingStrength string
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
|
||||
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
|
||||
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
|
||||
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
@ -41,7 +50,7 @@ type Store interface {
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
@ -60,14 +69,24 @@ type Store interface {
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
// Close should close the store persisting all unsaved data.
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
|
120
management/server/testdata/extended-store.json
vendored
Normal file
120
management/server/testdata/extended-store.json
vendored
Normal file
@ -0,0 +1,120 @@
|
||||
{
|
||||
"Accounts": {
|
||||
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
"CreatedBy": "",
|
||||
"Domain": "test.com",
|
||||
"DomainCategory": "private",
|
||||
"IsDomainPrimaryAccount": true,
|
||||
"SetupKeys": {
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
"AccountID": "",
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
"Name": "Default key",
|
||||
"Type": "reusable",
|
||||
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0,
|
||||
"LastUsed": "0001-01-01T00:00:00Z",
|
||||
"AutoGroups": ["cfefqs706sqkneg59g2g"],
|
||||
"UsageLimit": 0,
|
||||
"Ephemeral": false
|
||||
},
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
|
||||
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||
"AccountID": "",
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||
"Name": "Faulty key with non existing group",
|
||||
"Type": "reusable",
|
||||
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0,
|
||||
"LastUsed": "0001-01-01T00:00:00Z",
|
||||
"AutoGroups": ["abcd"],
|
||||
"UsageLimit": 0,
|
||||
"Ephemeral": false
|
||||
}
|
||||
},
|
||||
"Network": {
|
||||
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
|
||||
"Net": {
|
||||
"IP": "100.64.0.0",
|
||||
"Mask": "//8AAA=="
|
||||
},
|
||||
"Dns": "",
|
||||
"Serial": 0
|
||||
},
|
||||
"Peers": {},
|
||||
"Users": {
|
||||
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
"AccountID": "",
|
||||
"Role": "admin",
|
||||
"IsServiceUser": false,
|
||||
"ServiceUserName": "",
|
||||
"AutoGroups": ["cfefqs706sqkneg59g3g"],
|
||||
"PATs": {},
|
||||
"Blocked": false,
|
||||
"LastLogin": "0001-01-01T00:00:00Z"
|
||||
},
|
||||
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||
"AccountID": "",
|
||||
"Role": "user",
|
||||
"IsServiceUser": false,
|
||||
"ServiceUserName": "",
|
||||
"AutoGroups": null,
|
||||
"PATs": {
|
||||
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
|
||||
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||
"UserID": "",
|
||||
"Name": "",
|
||||
"HashedToken": "SoMeHaShEdToKeN",
|
||||
"ExpirationDate": "2023-02-27T00:00:00Z",
|
||||
"CreatedBy": "user",
|
||||
"CreatedAt": "2023-01-01T00:00:00Z",
|
||||
"LastUsed": "2023-02-01T00:00:00Z"
|
||||
}
|
||||
},
|
||||
"Blocked": false,
|
||||
"LastLogin": "0001-01-01T00:00:00Z"
|
||||
}
|
||||
},
|
||||
"Groups": {
|
||||
"cfefqs706sqkneg59g4g": {
|
||||
"ID": "cfefqs706sqkneg59g4g",
|
||||
"Name": "All",
|
||||
"Peers": []
|
||||
},
|
||||
"cfefqs706sqkneg59g3g": {
|
||||
"ID": "cfefqs706sqkneg59g3g",
|
||||
"Name": "AwesomeGroup1",
|
||||
"Peers": []
|
||||
},
|
||||
"cfefqs706sqkneg59g2g": {
|
||||
"ID": "cfefqs706sqkneg59g2g",
|
||||
"Name": "AwesomeGroup2",
|
||||
"Peers": []
|
||||
}
|
||||
},
|
||||
"Rules": null,
|
||||
"Policies": [],
|
||||
"Routes": null,
|
||||
"NameServerGroups": null,
|
||||
"DNSSettings": null,
|
||||
"Settings": {
|
||||
"PeerLoginExpirationEnabled": false,
|
||||
"PeerLoginExpiration": 86400000000000,
|
||||
"GroupsPropagationEnabled": false,
|
||||
"JWTGroupsEnabled": false,
|
||||
"JWTGroupsClaimName": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"InstallationID": ""
|
||||
}
|
@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||
}
|
||||
|
||||
func (u *User) updateLastLogin(login time.Time) {
|
||||
u.LastLogin = login
|
||||
}
|
||||
|
||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||
func (u *User) HasAdminPower() bool {
|
||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||
@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||
|
||||
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user