mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-26 01:53:42 +01:00
Store domain information (#217)
* extract claim information from JWT * get account function * Store domain * tests missing domain * update existing account with domain * add store domain tests
This commit is contained in:
parent
919f0aa3da
commit
cd9a418df2
@ -23,6 +23,7 @@ type Account struct {
|
|||||||
Id string
|
Id string
|
||||||
// User.Id it was created by
|
// User.Id it was created by
|
||||||
CreatedBy string
|
CreatedBy string
|
||||||
|
Domain string
|
||||||
SetupKeys map[string]*SetupKey
|
SetupKeys map[string]*SetupKey
|
||||||
Network *Network
|
Network *Network
|
||||||
Peers map[string]*Peer
|
Peers map[string]*Peer
|
||||||
@ -30,9 +31,9 @@ type Account struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
||||||
func NewAccount(userId string) *Account {
|
func NewAccount(userId, domain string) *Account {
|
||||||
accountId := xid.New().String()
|
accountId := xid.New().String()
|
||||||
return newAccountWithId(accountId, userId)
|
return newAccountWithId(accountId, userId, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) Copy() *Account {
|
func (a *Account) Copy() *Account {
|
||||||
@ -149,8 +150,8 @@ func (am *AccountManager) RenameSetupKey(accountId string, keyId string, newName
|
|||||||
return keyCopy, nil
|
return keyCopy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//GetAccount returns an existing account or error (NotFound) if doesn't exist
|
//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
|
||||||
func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
|
func (am *AccountManager) GetAccountById(accountId string) (*Account, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
@ -164,12 +165,12 @@ func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
|
|||||||
|
|
||||||
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
||||||
// user id doesn't have an account associated with it, one account is created
|
// user id doesn't have an account associated with it, one account is created
|
||||||
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId string) (*Account, error) {
|
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
|
||||||
|
|
||||||
if accountId != "" {
|
if accountId != "" {
|
||||||
return am.GetAccount(accountId)
|
return am.GetAccountById(accountId)
|
||||||
} else if userId != "" {
|
} else if userId != "" {
|
||||||
account, err := am.GetOrCreateAccountByUser(userId)
|
account, err := am.GetOrCreateAccountByUser(userId, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
|
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
|
||||||
}
|
}
|
||||||
@ -207,17 +208,17 @@ func (am *AccountManager) AccountExists(accountId string) (*bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store
|
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store
|
||||||
func (am *AccountManager) AddAccount(accountId string, userId string) (*Account, error) {
|
func (am *AccountManager) AddAccount(accountId, userId, domain string) (*Account, error) {
|
||||||
|
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
return am.createAccount(accountId, userId)
|
return am.createAccount(accountId, userId, domain)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AccountManager) createAccount(accountId string, userId string) (*Account, error) {
|
func (am *AccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
|
||||||
account := newAccountWithId(accountId, userId)
|
account := newAccountWithId(accountId, userId, domain)
|
||||||
|
|
||||||
err := am.Store.SaveAccount(account)
|
err := am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -228,7 +229,7 @@ func (am *AccountManager) createAccount(accountId string, userId string) (*Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||||
func newAccountWithId(accountId string, userId string) *Account {
|
func newAccountWithId(accountId, userId, domain string) *Account {
|
||||||
|
|
||||||
log.Debugf("creating new account")
|
log.Debugf("creating new account")
|
||||||
|
|
||||||
@ -243,7 +244,15 @@ func newAccountWithId(accountId string, userId string) *Account {
|
|||||||
|
|
||||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
||||||
|
|
||||||
return &Account{Id: accountId, SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userId}
|
return &Account{
|
||||||
|
Id: accountId,
|
||||||
|
SetupKeys: setupKeys,
|
||||||
|
Network: network,
|
||||||
|
Peers: peers,
|
||||||
|
Users: users,
|
||||||
|
CreatedBy: userId,
|
||||||
|
Domain: domain,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
|
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
|
||||||
|
@ -14,7 +14,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userId := "test_user"
|
userId := "test_user"
|
||||||
account, err := manager.GetOrCreateAccountByUser(userId)
|
account, err := manager.GetOrCreateAccountByUser(userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -32,6 +32,43 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := "test_user"
|
||||||
|
domain := "hotmail.com"
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(userId, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if account == nil {
|
||||||
|
t.Fatalf("expected to create an account for a user %s", userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Domain != domain {
|
||||||
|
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
domain = "gmail.com"
|
||||||
|
|
||||||
|
account, err = manager.GetOrCreateAccountByUser(userId, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account == nil {
|
||||||
|
t.Fatalf("expected to get an account for a user %s", userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Domain != domain {
|
||||||
|
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_AddAccount(t *testing.T) {
|
func TestAccountManager_AddAccount(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -48,7 +85,7 @@ func TestAccountManager_AddAccount(t *testing.T) {
|
|||||||
Mask: net.IPMask{255, 192, 0, 0},
|
Mask: net.IPMask{255, 192, 0, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := manager.AddAccount(expectedId, userId)
|
account, err := manager.AddAccount(expectedId, userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -79,7 +116,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
|
|
||||||
userId := "test_user"
|
userId := "test_user"
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountId(userId, "")
|
account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -89,12 +126,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
|
|
||||||
accountId := account.Id
|
accountId := account.Id
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountId("", accountId)
|
_, err = manager.GetAccountByUserOrAccountId("", accountId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
|
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountId("", "")
|
_, err = manager.GetAccountByUserOrAccountId("", "", "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected an error when user and account IDs are empty")
|
t.Errorf("expected an error when user and account IDs are empty")
|
||||||
}
|
}
|
||||||
@ -109,7 +146,7 @@ func TestAccountManager_AccountExists(t *testing.T) {
|
|||||||
|
|
||||||
expectedId := "test_account"
|
expectedId := "test_account"
|
||||||
userId := "account_creator"
|
userId := "account_creator"
|
||||||
_, err = manager.AddAccount(expectedId, userId)
|
_, err = manager.AddAccount(expectedId, userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -134,13 +171,13 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
|||||||
|
|
||||||
expectedId := "test_account"
|
expectedId := "test_account"
|
||||||
userId := "account_creator"
|
userId := "account_creator"
|
||||||
account, err := manager.AddAccount(expectedId, userId)
|
account, err := manager.AddAccount(expectedId, userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//AddAccount has been already tested so we can assume it is correct and compare results
|
//AddAccount has been already tested so we can assume it is correct and compare results
|
||||||
getAccount, err := manager.GetAccount(expectedId)
|
getAccount, err := manager.GetAccountById(expectedId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@ -171,7 +208,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := manager.AddAccount("test_account", "account_creator")
|
account, err := manager.AddAccount("test_account", "account_creator", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -211,7 +248,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err = manager.GetAccount(account.Id)
|
account, err = manager.GetAccountById(account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@ -238,7 +275,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := manager.AddAccount("test_account", "account_creator")
|
account, err := manager.AddAccount("test_account", "account_creator", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -271,7 +308,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err = manager.GetAccount(account.Id)
|
account, err = manager.GetAccountById(account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
|
@ -32,7 +32,7 @@ func TestNewStore(t *testing.T) {
|
|||||||
func TestSaveAccount(t *testing.T) {
|
func TestSaveAccount(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
|
|
||||||
account := NewAccount("testuser")
|
account := NewAccount("testuser", "")
|
||||||
account.Users["testuser"] = NewAdminUser("testuser")
|
account.Users["testuser"] = NewAdminUser("testuser")
|
||||||
setupKey := GenerateDefaultSetupKey()
|
setupKey := GenerateDefaultSetupKey()
|
||||||
account.SetupKeys[setupKey.Key] = setupKey
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
|
|||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
|
|
||||||
account := NewAccount("testuser")
|
account := NewAccount("testuser", "")
|
||||||
account.Users["testuser"] = NewAdminUser("testuser")
|
account.Users["testuser"] = NewAdminUser("testuser")
|
||||||
account.Peers["testpeer"] = &Peer{
|
account.Peers["testpeer"] = &Peer{
|
||||||
Key: "peerkey",
|
Key: "peerkey",
|
||||||
|
@ -63,12 +63,21 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
|
|||||||
writeJSONObject(w, "")
|
writeJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
||||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||||
//new user -> create a new account
|
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getPeerAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -105,11 +114,9 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
account, err := h.getPeerAccount(r)
|
||||||
//new user -> create a new account
|
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/wiretrustee/wiretrustee/management/server"
|
"github.com/wiretrustee/wiretrustee/management/server"
|
||||||
@ -78,7 +79,7 @@ func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
||||||
account, err := h.accountManager.GetAccount(accountId)
|
account, err := h.accountManager.GetAccountById(accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
|
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@ -119,11 +120,21 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
|||||||
writeSuccess(w, setupKey)
|
writeSuccess(w, setupKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
|
||||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
|
||||||
|
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getSetupKeyAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -149,10 +160,9 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
account, err := h.getSetupKeyAccount(r)
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -8,17 +8,28 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// extractUserAndAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth)
|
// JWTClaims stores information from JWTs
|
||||||
func extractUserAndAccountIdFromRequestContext(r *http.Request, authAudiance string) (userId, accountId string) {
|
type JWTClaims struct {
|
||||||
|
UserId string
|
||||||
|
AccountId string
|
||||||
|
Domain string
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||||
|
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
|
||||||
token := r.Context().Value("user").(*jwt.Token)
|
token := r.Context().Value("user").(*jwt.Token)
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
jwtClaims := JWTClaims{}
|
||||||
userId = claims["sub"].(string)
|
jwtClaims.UserId = claims["sub"].(string)
|
||||||
accountIdInt, ok := claims[authAudiance+"wt_account_id"]
|
accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
|
||||||
if ok {
|
if ok {
|
||||||
accountId = accountIdInt.(string)
|
jwtClaims.AccountId = accountIdClaim.(string)
|
||||||
}
|
}
|
||||||
return userId, accountId
|
domainClaim, ok := claims[authAudiance+"wt_user_domain"]
|
||||||
|
if ok {
|
||||||
|
jwtClaims.AccountId = domainClaim.(string)
|
||||||
|
}
|
||||||
|
return jwtClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
||||||
|
@ -14,7 +14,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
|||||||
|
|
||||||
expectedId := "test_account"
|
expectedId := "test_account"
|
||||||
userId := "account_creator"
|
userId := "account_creator"
|
||||||
account, err := manager.AddAccount(expectedId, userId)
|
account, err := manager.AddAccount(expectedId, userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -40,14 +40,14 @@ func NewAdminUser(id string) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||||
func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, error) {
|
func (am *AccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetUserAccount(userId)
|
account, err := am.Store.GetUserAccount(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
account = NewAccount(userId)
|
account = NewAccount(userId, domain)
|
||||||
account.Users[userId] = NewAdminUser(userId)
|
account.Users[userId] = NewAdminUser(userId)
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -59,6 +59,14 @@ func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Domain != domain {
|
||||||
|
account.Domain = domain
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user