Introduce locking on the account level (#548)

This commit is contained in:
Misha Bragin 2022-11-07 17:52:23 +01:00 committed by GitHub
parent 1f845f466c
commit ed7ac81027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 200 additions and 166 deletions

View File

@ -38,7 +38,6 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
CreateSetupKey( CreateSetupKey(
accountId string, accountId string,
keyName string, keyName string,
@ -51,8 +50,7 @@ type AccountManager interface {
ListSetupKeys(accountID, userID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error) SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
@ -97,8 +95,6 @@ type AccountManager interface {
type DefaultAccountManager struct { type DefaultAccountManager struct {
Store Store Store Store
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
@ -359,7 +355,6 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) { singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
@ -460,32 +455,17 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil return nil
} }
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) { // userID doesn't have an account associated with it, one account is created
am.mux.Lock() func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
defer am.mux.Unlock() if accountID != "" {
return am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(accountId) } else if userID != "" {
if err != nil { account, err := am.GetOrCreateAccountByUser(userID, domain)
return nil, status.Errorf(codes.NotFound, "account not found")
}
return account, nil
}
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId, domain)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
} }
err = am.addAccountIDToIDPAppMeta(userId, account) err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -825,15 +805,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId) accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -845,8 +823,8 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
} }
} }
am.mux.Lock() unlock := am.Store.AcquireGlobalLock()
defer am.mux.Unlock() defer unlock()
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
@ -876,12 +854,13 @@ func isDomainValid(domain string) bool {
} }
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool var res bool
_, err := am.Store.GetAccount(accountId) _, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
res = false res = false

View File

@ -121,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@ -302,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
@ -345,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@ -401,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "", "") account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -411,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId, "") _, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
} }
_, err = manager.GetAccountByUserOrAccountId("", "", "") _, err = manager.GetAccountByUserOrAccountID("", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@ -470,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
} }
// AddAccount has been already tested so we can assume it is correct and compare results // AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId) getAccount, err := manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -540,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -602,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -680,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer() peer2 := getPeer()
peer3 := getPeer() peer3 := getPeer()
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -848,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
account, err = manager.GetAccountById(account.Id) account, err = manager.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@ -1,10 +1,12 @@
package server package server
import ( import (
log "github.com/sirupsen/logrus"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -27,6 +29,10 @@ type FileStore struct {
// mutex to synchronise Store read/write operations // mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"` mux sync.Mutex `json:"-"`
storeFile string `json:"-"` storeFile string `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
} }
type StoredAccount struct{} type StoredAccount struct{}
@ -44,6 +50,7 @@ func restore(file string) (*FileStore, error) {
s := &FileStore{ s := &FileStore{
Accounts: make(map[string]*Account), Accounts: make(map[string]*Account),
mux: sync.Mutex{}, mux: sync.Mutex{},
globalAccountLock: sync.Mutex{},
SetupKeyID2AccountID: make(map[string]string), SetupKeyID2AccountID: make(map[string]string),
PeerKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string),
UserID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string),
@ -111,7 +118,36 @@ func (s *FileStore) persist(file string) error {
return util.WriteJson(file, s) return util.WriteJson(file, s)
} }
// SaveAccount updates an existing account or adds a new one // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
return unlock
}
// AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()

View File

@ -47,8 +47,9 @@ func (g *Group) Copy() *Group {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -65,8 +66,9 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error { func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -86,8 +88,9 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
// UpdateGroup updates a group using a list of operations // UpdateGroup updates a group using a list of operations
func (am *DefaultAccountManager) UpdateGroup(accountID string, func (am *DefaultAccountManager) UpdateGroup(accountID string,
groupID string, operations []GroupUpdateOperation) (*Group, error) { groupID string, operations []GroupUpdateOperation) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -135,8 +138,9 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -155,8 +159,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -173,8 +178,9 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -207,8 +213,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -235,8 +242,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
// GroupListPeers returns list of the peers from the group // GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@ -15,7 +15,6 @@ type MockAccountManager struct {
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
@ -114,16 +113,8 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) {
if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface // GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountId( func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string, userId, accountId, domain string,
) (*server.Account, error) { ) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil { if am.GetAccountByUserOrAccountIdFunc != nil {
@ -131,7 +122,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountId is not implemented", "method GetAccountByUserOrAccountID is not implemented",
) )
} }

View File

@ -60,8 +60,9 @@ type NameServerGroupUpdateOperation struct {
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -78,8 +79,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -125,8 +127,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
@ -161,8 +164,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
// UpdateNameServerGroup updates existing nameserver group with set of operations // UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -263,8 +267,9 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -290,8 +295,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@ -635,7 +635,7 @@ func TestSaveNameServerGroup(t *testing.T) {
return return
} }
account, err = am.GetAccountById(account.Id) account, err = am.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -76,8 +76,6 @@ func (p *Peer) Copy() *Peer {
// GetPeer looks up peer by its public WireGuard key // GetPeer looks up peer by its public WireGuard key
func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) { func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
@ -90,8 +88,7 @@ func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -116,14 +113,21 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error { func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
@ -143,8 +147,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
// UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) { func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -188,8 +193,9 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) { func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -237,8 +243,9 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
// GetPeerByIP returns peer by its IP // GetPeerByIP returns peer by its IP
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) { func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock() unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -256,8 +263,6 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) { func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
@ -292,8 +297,6 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) { func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
@ -311,8 +314,6 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further // The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) { func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
upperKey := strings.ToUpper(setupKey) upperKey := strings.ToUpper(setupKey)
@ -367,6 +368,15 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided") return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided")
} }
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, err
}
var takenIps []net.IP var takenIps []net.IP
existingLabels := make(lookupMap) existingLabels := make(lookupMap)
for _, existingPeer := range account.Peers { for _, existingPeer := range account.Peers {
@ -433,8 +443,6 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
// UpdatePeerSSHKey updates peer's public SSH key // UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error { func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
if sshKey == "" { if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey) log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey)
@ -446,6 +454,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri
return err return err
} }
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
@ -470,14 +487,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri
// UpdatePeerMeta updates peer's system metadata // UpdatePeerMeta updates peer's system metadata
func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error { func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
} }
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err

View File

@ -61,8 +61,8 @@ type RouteUpdateOperation struct {
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -116,8 +116,8 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -180,8 +180,8 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil") return status.Errorf(codes.InvalidArgument, "route provided is nil")
@ -223,8 +223,8 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
// UpdateRoute updates existing route with set of operations // UpdateRoute updates existing route with set of operations
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -320,8 +320,8 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -340,8 +340,8 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@ -380,7 +380,7 @@ func TestSaveRoute(t *testing.T) {
return return
} }
account, err = am.GetAccountById(account.Id) account, err = am.Store.GetAccount(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -845,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
return am.GetAccountById(accountID) return am.Store.GetAccount(account.Id)
} }

View File

@ -90,8 +90,8 @@ func (r *Rule) Copy() *Rule {
// GetRule of ACL from the store // GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) { func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -117,8 +117,8 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
// SaveRule of ACL in the store // SaveRule of ACL in the store
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -138,8 +138,8 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
// UpdateRule updates a rule using a list of operations // UpdateRule updates a rule using a list of operations
func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
operations []RuleUpdateOperation) (*Rule, error) { operations []RuleUpdateOperation) (*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -212,8 +212,8 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
// DeleteRule of ACL from the store // DeleteRule of ACL from the store
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@ -232,8 +232,8 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
// ListRules of ACL from the store // ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) { func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@ -173,8 +173,8 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty. // and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) { expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
keyDuration := DefaultSetupKeyDuration keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 { if expiresIn != 0 {
@ -208,8 +208,8 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
@ -249,8 +249,8 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(codes.NotFound, "account not found")
@ -277,8 +277,8 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {

View File

@ -10,4 +10,8 @@ type Store interface {
SaveAccount(account *Account) error SaveAccount(account *Account) error
GetInstallationID() string GetInstallationID() string
SaveInstallationID(id string) error SaveInstallationID(id string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
} }

View File

@ -119,8 +119,8 @@ func NewAdminUser(id string) *User {
// CreateUser creates a new user under the given account. Effectively this is a user invite. // CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
@ -184,8 +184,8 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now. // Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
am.mux.Lock() unlock := am.Store.AcquireAccountLock(accountID)
defer am.mux.Unlock() defer unlock()
if update == nil { if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil") return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
@ -234,16 +234,16 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) { func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
am.mux.Lock() unlock := am.Store.AcquireGlobalLock()
defer am.mux.Unlock() defer unlock()
lowerDomain := strings.ToLower(domain) lowerDomain := strings.ToLower(domain)
account, err := am.Store.GetAccountByUser(userId) account, err := am.Store.GetAccountByUser(userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account, err = am.newAccount(userId, lowerDomain) account, err = am.newAccount(userID, lowerDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -257,7 +257,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
} }
} }
userObj := account.Users[userId] userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain account.Domain = lowerDomain
@ -270,14 +270,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
return account, nil return account, nil
} }
// GetAccountByUser returns an existing account for a given user id, NotFound if account couldn't be found
func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.Store.GetAccountByUser(userId)
}
// IsUserAdmin flag for current user authenticated by JWT token // IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims) account, err := am.GetAccountFromToken(claims)
@ -296,7 +288,7 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role. // based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }