[management] Add transaction to addPeer (#2469)

This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
This commit is contained in:
pascal-fischer 2024-09-16 15:47:03 +02:00 committed by GitHub
parent 730dd1733e
commit 6c50b0c84b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1095 additions and 216 deletions

View File

@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View File

@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
} }
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
type UserPermissions struct { type UserPermissions struct {
DashboardView string `json:"dashboard_view"` DashboardView string `json:"dashboard_view"`
} }
@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps return grps
} }
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool { func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID) peerGroups := a.getPeerGroups(peerID)
enabled := true enabled := true
@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList return groupList
} }
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP { func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP var takenIps []net.IP
for _, existingPeer := range a.Peers { for _, existingPeer := range a.Peers {
@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil return false, nil
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {

View File

@ -7,6 +7,7 @@ import (
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
) )
type MockStore struct { type MockStore struct {
@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil return s.account, nil
} }
return nil, fmt.Errorf("account not found") return nil, status.NewPeerNotFoundError(peerId)
} }
type MocAccountManager struct { type MocAccountManager struct {

View File

@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"errors"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"` metrics telemetry.AppMetrics `json:"-"`
} }
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{} type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir // NewFileStore restores a store from the file located in the datadir
@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") return nil, status.NewSetupKeyNotFoundError()
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil return account.Users[userID].Copy(), nil
} }
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) { func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID] accountID, ok := s.UserID2AccountID[userID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) { func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID] account, ok := s.Accounts[accountID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "account not found") return nil, status.NewAccountNotFoundError(accountID)
} }
return account, nil return account, nil
@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok { if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") return "", status.NewSetupKeyNotFoundError()
} }
return accountID, nil return accountID, nil
} }
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey) return nil, status.NewPeerNotFoundError(peerKey)
} }
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
} }
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. // SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()

View File

@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
} }
func Test_LoginPerformance(t *testing.T) { func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping on CI") t.Skip("Skipping test on CI or Windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1}, // {"M", 250, 1},
// {"L", 500, 1}, // {"L", 500, 1},
// {"XL", 750, 1}, // {"XL", 750, 1},
{"XXL", 2000, 1}, {"XXL", 5000, 1},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
} }
defer mgmtServer.GracefulStop() defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32 var counter int32
var counterStart int32 var counterStart int32
var wg sync.WaitGroup var wgAccount sync.WaitGroup
var mu sync.Mutex var mu sync.Mutex
messageCalls := []func() error{} messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ { for j := 0; j < bc.accounts; j++ {
wg.Add(1) wgAccount.Add(1)
var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) { go func(j int, counter *int32, counterStart *int32) {
defer wg.Done() defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j)) account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil { if err != nil {
@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return return
} }
startTime := time.Now()
for i := 0; i < bc.peers; i++ { for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Logf("failed to generate key: %v", err) t.Logf("failed to generate key: %v", err)
@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock() mu.Lock()
messageCalls = append(messageCalls, login) messageCalls = append(messageCalls, login)
mu.Unlock() mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1) go func(peerLogin PeerLogin, counterStart *int32) {
if *counterStart%100 == 0 { defer wgPeer.Done()
t.Logf("registered %d peers", *counterStart) _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
} if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
} }
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart) }(j, &counter, &counterStart)
} }
wg.Wait() wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls)) t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls) testLoginPerformance(t, messageCalls)

View File

@ -11,6 +11,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
}() }()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key) _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil { if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
} }
opEvent := &activity.Event{ opEvent := &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
AccountID: account.Id, AccountID: accountID,
} }
var ephemeral bool var newPeer *nbpeer.Peer
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, nil, err
}
if !sk.IsValid() { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") var groupsToAdd []string
} var setupKeyID string
var setupKeyName string
account.SetupKeys[sk.Key] = sk.IncrementUsage() var ephemeral bool
opEvent.InitiatorID = sk.Id if addedByUser {
opEvent.Activity = activity.PeerAddedWithSetupKey user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
ephemeral = sk.Ephemeral if err != nil {
setupKeyName = sk.Name return fmt.Errorf("failed to get user groups: %w", err)
} else { }
opEvent.InitiatorID = userID groupsToAdd = user.AutoGroups
opEvent.Activity = activity.PeerAddedByUser opEvent.InitiatorID = userID
} opEvent.Activity = activity.PeerAddedByUser
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else { } else {
newPeer.Location.CountryCode = location.Country.ISOCode // Validate the setup key
newPeer.Location.CityName = location.City.Names.En sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
newPeer.Location.GeoNameID = location.City.GeonameID if err != nil {
} return fmt.Errorf("failed to get setup key: %w", err)
} }
// add peer to 'All' group if !sk.IsValid() {
group, err := account.GetGroupAll() return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
if err != nil { }
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
var groupsToAdd []string opEvent.InitiatorID = sk.Id
if addedByUser { opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd, err = account.getUserGroups(userID) groupsToAdd = sk.AutoGroups
if err != nil { ephemeral = sk.Ephemeral
return nil, nil, nil, err setupKeyID = sk.Id
setupKeyName = sk.Name
} }
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 { if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
for _, s := range groupsToAdd { if am.idpManager != nil {
if g, ok := account.Groups[s]; ok && g.Name != "All" { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
g.Peers = append(g.Peers, newPeer.ID) if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
} }
} }
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if addedByUser {
user, err := account.FindUser(userID)
if err != nil { if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") return fmt.Errorf("failed to get free DNS label: %w", err)
} }
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer freeIP, err := am.getFreeIP(ctx, transaction, accountID)
account.Network.IncSerial() if err != nil {
err = am.Store.SaveAccount(ctx, account) return fmt.Errorf("failed to get free IP: %w", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
} }
// Account is saved, we can release the lock if newPeer == nil {
unlock() return nil, nil, nil, fmt.Errorf("new peer is nil")
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
} }
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
} }
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
}() }()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err return err
} }
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin) err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil { if err != nil {
return err return err
} }
@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
} }
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@ -7,20 +7,24 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
) )
@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
} }
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -33,6 +34,7 @@ import (
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found" peerNotFoundFMT = "peer %s not found"
) )
@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return nil, status.NewSetupKeyNotFoundError()
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
} }
if key.AccountID == "" { if key.AccountID == "" {
@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil return &user, nil
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User var user User
result := s.db.First(&user, idQueryCondition, userID) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed") return nil, status.NewUserNotFoundError(userID)
} }
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error) return nil, status.NewGetUserFromStoreError()
return nil, status.Errorf(status.Internal, "issue getting user from store")
} }
return &user, nil return &user, nil
@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found") return nil, status.NewAccountNotFoundError(accountID)
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID) result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
var accountID string var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.Errorf(status.Internal, "issue getting account from store")
} }
@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
} }
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return "", status.NewSetupKeyNotFoundError()
return "", status.Errorf(status.Internal, "issue getting setup key from store") }
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return accountID, nil return accountID, nil
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found") return nil, status.Errorf(status.NotFound, "peer not found")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store") return nil, status.Errorf(status.Internal, "issue getting peer from store")
} }
return &peer, nil return &peer, nil
} }
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store") return nil, status.Errorf(status.Internal, "issue getting settings from store")
} }
return accountSettings.Settings, nil return accountSettings.Settings, nil
} }
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID) return status.NewUserNotFoundError(userID)
} }
return status.Errorf(status.Internal, "issue getting user from store") return status.NewGetUserFromStoreError()
} }
user.LastLogin = lastLogin user.LastLogin = lastLogin
return s.db.Save(user).Error return s.db.Save(&user).Error
} }
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil return store, nil
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}

View File

@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, id, user.PATs[id].ID)
} }
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error { func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more") return Errorf(PermissionDenied, "peer login has expired, please log in once more")
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@ -27,6 +27,15 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface { type Store interface {
GetAllAccounts(ctx context.Context) []*Account GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
@ -41,7 +50,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
@ -60,14 +69,24 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data. // Close should close the store persisting all unsaved data.
Close(ctx context.Context) error Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation. // GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
} }
type StoreEngine string type StoreEngine string

View File

@ -0,0 +1,120 @@
{
"Accounts": {
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
"CreatedBy": "",
"Domain": "test.com",
"DomainCategory": "private",
"IsDomainPrimaryAccount": true,
"SetupKeys": {
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"Name": "Default key",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["cfefqs706sqkneg59g2g"],
"UsageLimit": 0,
"Ephemeral": false
},
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"Name": "Faulty key with non existing group",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["abcd"],
"UsageLimit": 0,
"Ephemeral": false
}
},
"Network": {
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": "",
"Serial": 0
},
"Peers": {},
"Users": {
"edafee4e-63fb-11ec-90d6-0242ac120003": {
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "admin",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": ["cfefqs706sqkneg59g3g"],
"PATs": {},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
},
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "user",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": null,
"PATs": {
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
"UserID": "",
"Name": "",
"HashedToken": "SoMeHaShEdToKeN",
"ExpirationDate": "2023-02-27T00:00:00Z",
"CreatedBy": "user",
"CreatedAt": "2023-01-01T00:00:00Z",
"LastUsed": "2023-02-01T00:00:00Z"
}
},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": {
"cfefqs706sqkneg59g4g": {
"ID": "cfefqs706sqkneg59g4g",
"Name": "All",
"Peers": []
},
"cfefqs706sqkneg59g3g": {
"ID": "cfefqs706sqkneg59g3g",
"Name": "AwesomeGroup1",
"Peers": []
},
"cfefqs706sqkneg59g2g": {
"ID": "cfefqs706sqkneg59g2g",
"Name": "AwesomeGroup2",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,
"Settings": {
"PeerLoginExpirationEnabled": false,
"PeerLoginExpiration": 86400000000000,
"GroupsPropagationEnabled": false,
"JWTGroupsEnabled": false,
"JWTGroupsClaimName": ""
}
}
},
"InstallationID": ""
}

View File

@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
} }
func (u *User) updateLastLogin(login time.Time) {
u.LastLogin = login
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise // HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool { func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin) newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err) log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
} }