Introduce locking on the account level (#548)

This commit is contained in:
Misha Bragin 2022-11-07 17:52:23 +01:00 committed by GitHub
parent 1f845f466c
commit ed7ac81027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 200 additions and 166 deletions

View File

@ -38,7 +38,6 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey(
accountId string,
keyName string,
@ -51,8 +50,7 @@ type AccountManager interface {
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
@ -97,8 +95,6 @@ type AccountManager interface {
type DefaultAccountManager struct {
Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
@ -359,7 +355,6 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
@ -460,32 +455,17 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil
}
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
return account, nil
}
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId, domain)
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
}
err = am.addAccountIDToIDPAppMeta(userId, account)
err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil {
return nil, err
}
@ -825,15 +805,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId)
accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil {
return nil, err
}
@ -845,8 +823,8 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
}
}
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireGlobalLock()
defer unlock()
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
@ -876,12 +854,13 @@ func isDomainValid(domain string) bool {
}
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool
_, err := am.Store.GetAccount(accountId)
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
res = false

View File

@ -121,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@ -302,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs {
@ -345,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@ -401,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil {
t.Fatal(err)
}
@ -411,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId, "")
_, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
}
_, err = manager.GetAccountByUserOrAccountId("", "", "")
_, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@ -470,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId)
getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@ -540,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@ -602,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@ -680,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return
@ -848,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
account, err = manager.GetAccountById(account.Id)
account, err = manager.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
return

View File

@ -1,10 +1,12 @@
package server
import (
log "github.com/sirupsen/logrus"
"os"
"path/filepath"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -27,6 +29,10 @@ type FileStore struct {
// mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
}
type StoredAccount struct{}
@ -44,6 +50,7 @@ func restore(file string) (*FileStore, error) {
s := &FileStore{
Accounts: make(map[string]*Account),
mux: sync.Mutex{},
globalAccountLock: sync.Mutex{},
SetupKeyID2AccountID: make(map[string]string),
PeerKeyID2AccountID: make(map[string]string),
UserID2AccountID: make(map[string]string),
@ -111,7 +118,36 @@ func (s *FileStore) persist(file string) error {
return util.WriteJson(file, s)
}
// SaveAccount updates an existing account or adds a new one
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
return unlock
}
// AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()

View File

@ -47,8 +47,9 @@ func (g *Group) Copy() *Group {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -65,8 +66,9 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -86,8 +88,9 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
// UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -135,8 +138,9 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -155,8 +159,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -173,8 +178,9 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -207,8 +213,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -235,8 +242,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
// GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@ -15,7 +15,6 @@ type MockAccountManager struct {
GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error)
@ -114,16 +113,8 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) {
if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountId(
func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string,
) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil {
@ -131,7 +122,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
}
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountByUserOrAccountId is not implemented",
"method GetAccountByUserOrAccountID is not implemented",
)
}

View File

@ -60,8 +60,9 @@ type NameServerGroupUpdateOperation struct {
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -78,8 +79,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -125,8 +127,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
@ -161,8 +164,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -263,8 +267,9 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -290,8 +295,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@ -635,7 +635,7 @@ func TestSaveNameServerGroup(t *testing.T) {
return
}
account, err = am.GetAccountById(account.Id)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}

View File

@ -76,8 +76,6 @@ func (p *Peer) Copy() *Peer {
// GetPeer looks up peer by its public WireGuard key
func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
@ -90,8 +88,7 @@ func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
@ -116,14 +113,21 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
@ -143,8 +147,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
// UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -188,8 +193,9 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -237,8 +243,9 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
// GetPeerByIP returns peer by its IP
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -256,8 +263,6 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
@ -292,8 +297,6 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
// GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
@ -311,8 +314,6 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
upperKey := strings.ToUpper(setupKey)
@ -367,6 +368,15 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, err
}
var takenIps []net.IP
existingLabels := make(lookupMap)
for _, existingPeer := range account.Peers {
@ -433,8 +443,6 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey)
@ -446,6 +454,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
@ -470,14 +487,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri
// UpdatePeerMeta updates peer's system metadata
func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err

View File

@ -61,8 +61,8 @@ type RouteUpdateOperation struct {
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -116,8 +116,8 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -180,8 +180,8 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil")
@ -223,8 +223,8 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
// UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -320,8 +320,8 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -340,8 +340,8 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@ -380,7 +380,7 @@ func TestSaveRoute(t *testing.T) {
return
}
account, err = am.GetAccountById(account.Id)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
t.Fatal(err)
}
@ -845,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err
}
return am.GetAccountById(accountID)
return am.Store.GetAccount(account.Id)
}

View File

@ -90,8 +90,8 @@ func (r *Rule) Copy() *Rule {
// GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -117,8 +117,8 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
// SaveRule of ACL in the store
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -138,8 +138,8 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
// UpdateRule updates a rule using a list of operations
func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
operations []RuleUpdateOperation) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -212,8 +212,8 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
// DeleteRule of ACL from the store
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@ -232,8 +232,8 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
// ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@ -173,8 +173,8 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
@ -208,8 +208,8 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
@ -249,8 +249,8 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
@ -277,8 +277,8 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {

View File

@ -10,4 +10,8 @@ type Store interface {
SaveAccount(account *Account) error
GetInstallationID() string
SaveInstallationID(id string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
}

View File

@ -119,8 +119,8 @@ func NewAdminUser(id string) *User {
// CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
@ -184,8 +184,8 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
@ -234,16 +234,16 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
unlock := am.Store.AcquireGlobalLock()
defer unlock()
lowerDomain := strings.ToLower(domain)
account, err := am.Store.GetAccountByUser(userId)
account, err := am.Store.GetAccountByUser(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account, err = am.newAccount(userId, lowerDomain)
account, err = am.newAccount(userID, lowerDomain)
if err != nil {
return nil, err
}
@ -257,7 +257,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
}
}
userObj := account.Users[userId]
userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
@ -270,14 +270,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
return account, nil
}
// GetAccountByUser returns an existing account for a given user id, NotFound if account couldn't be found
func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.Store.GetAccountByUser(userId)
}
// IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims)
@ -296,7 +288,7 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}