From cd9a418df2883858a14dbd005f57c3a4d994ac2a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 11 Feb 2022 17:18:18 +0100 Subject: [PATCH] Store domain information (#217) * extract claim information from JWT * get account function * Store domain * tests missing domain * update existing account with domain * add store domain tests --- management/server/account.go | 35 +++++++----- management/server/account_test.go | 61 +++++++++++++++++---- management/server/file_store_test.go | 4 +- management/server/http/handler/peers.go | 25 ++++++--- management/server/http/handler/setupkeys.go | 26 ++++++--- management/server/http/handler/util.go | 25 ++++++--- management/server/peer_test.go | 2 +- management/server/user.go | 12 +++- 8 files changed, 136 insertions(+), 54 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index f082741bf..4550c6379 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -23,6 +23,7 @@ type Account struct { Id string // User.Id it was created by CreatedBy string + Domain string SetupKeys map[string]*SetupKey Network *Network Peers map[string]*Peer @@ -30,9 +31,9 @@ type Account struct { } // NewAccount creates a new Account with a generated ID and generated default setup keys -func NewAccount(userId string) *Account { +func NewAccount(userId, domain string) *Account { accountId := xid.New().String() - return newAccountWithId(accountId, userId) + return newAccountWithId(accountId, userId, domain) } func (a *Account) Copy() *Account { @@ -149,8 +150,8 @@ func (am *AccountManager) RenameSetupKey(accountId string, keyId string, newName return keyCopy, nil } -//GetAccount returns an existing account or error (NotFound) if doesn't exist -func (am *AccountManager) GetAccount(accountId string) (*Account, error) { +//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist +func (am *AccountManager) GetAccountById(accountId string) (*Account, error) { am.mux.Lock() defer am.mux.Unlock() @@ -164,12 +165,12 @@ func (am *AccountManager) GetAccount(accountId string) (*Account, error) { //GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and // user id doesn't have an account associated with it, one account is created -func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId string) (*Account, error) { +func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) { if accountId != "" { - return am.GetAccount(accountId) + return am.GetAccountById(accountId) } else if userId != "" { - account, err := am.GetOrCreateAccountByUser(userId) + account, err := am.GetOrCreateAccountByUser(userId, domain) if err != nil { return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) } @@ -207,17 +208,17 @@ func (am *AccountManager) AccountExists(accountId string) (*bool, error) { } // AddAccount generates a new Account with a provided accountId and userId, saves to the Store -func (am *AccountManager) AddAccount(accountId string, userId string) (*Account, error) { +func (am *AccountManager) AddAccount(accountId, userId, domain string) (*Account, error) { am.mux.Lock() defer am.mux.Unlock() - return am.createAccount(accountId, userId) + return am.createAccount(accountId, userId, domain) } -func (am *AccountManager) createAccount(accountId string, userId string) (*Account, error) { - account := newAccountWithId(accountId, userId) +func (am *AccountManager) createAccount(accountId, userId, domain string) (*Account, error) { + account := newAccountWithId(accountId, userId, domain) err := am.Store.SaveAccount(account) if err != nil { @@ -228,7 +229,7 @@ func (am *AccountManager) createAccount(accountId string, userId string) (*Accou } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(accountId string, userId string) *Account { +func newAccountWithId(accountId, userId, domain string) *Account { log.Debugf("creating new account") @@ -243,7 +244,15 @@ func newAccountWithId(accountId string, userId string) *Account { log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) - return &Account{Id: accountId, SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userId} + return &Account{ + Id: accountId, + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userId, + Domain: domain, + } } func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey { diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f26ccb33..2ec6c8bb6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -14,7 +14,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(userId) + account, err := manager.GetOrCreateAccountByUser(userId, "") if err != nil { t.Fatal(err) } @@ -32,6 +32,43 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } } +func TestAccountManager_SetOrUpdateDomain(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + userId := "test_user" + domain := "hotmail.com" + account, err := manager.GetOrCreateAccountByUser(userId, domain) + if err != nil { + t.Fatal(err) + } + if account == nil { + t.Fatalf("expected to create an account for a user %s", userId) + } + + if account.Domain != domain { + t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) + } + + domain = "gmail.com" + + account, err = manager.GetOrCreateAccountByUser(userId, domain) + if err != nil { + t.Fatalf("got the following error while retrieving existing acc: %v", err) + } + + if account == nil { + t.Fatalf("expected to get an account for a user %s", userId) + } + + if account.Domain != domain { + t.Errorf("updating domain. expected %s got %s", domain, account.Domain) + } +} + func TestAccountManager_AddAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -48,7 +85,7 @@ func TestAccountManager_AddAccount(t *testing.T) { Mask: net.IPMask{255, 192, 0, 0}, } - account, err := manager.AddAccount(expectedId, userId) + account, err := manager.AddAccount(expectedId, userId, "") if err != nil { t.Fatal(err) } @@ -79,7 +116,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountId(userId, "") + account, err := manager.GetAccountByUserOrAccountId(userId, "", "") if err != nil { t.Fatal(err) } @@ -89,12 +126,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { accountId := account.Id - _, err = manager.GetAccountByUserOrAccountId("", accountId) + _, err = manager.GetAccountByUserOrAccountId("", accountId, "") if err != nil { t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) } - _, err = manager.GetAccountByUserOrAccountId("", "") + _, err = manager.GetAccountByUserOrAccountId("", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -109,7 +146,7 @@ func TestAccountManager_AccountExists(t *testing.T) { expectedId := "test_account" userId := "account_creator" - _, err = manager.AddAccount(expectedId, userId) + _, err = manager.AddAccount(expectedId, userId, "") if err != nil { t.Fatal(err) } @@ -134,13 +171,13 @@ func TestAccountManager_GetAccount(t *testing.T) { expectedId := "test_account" userId := "account_creator" - account, err := manager.AddAccount(expectedId, userId) + account, err := manager.AddAccount(expectedId, userId, "") if err != nil { t.Fatal(err) } //AddAccount has been already tested so we can assume it is correct and compare results - getAccount, err := manager.GetAccount(expectedId) + getAccount, err := manager.GetAccountById(expectedId) if err != nil { t.Fatal(err) return @@ -171,7 +208,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator") + account, err := manager.AddAccount("test_account", "account_creator", "") if err != nil { t.Fatal(err) } @@ -211,7 +248,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err = manager.GetAccount(account.Id) + account, err = manager.GetAccountById(account.Id) if err != nil { t.Fatal(err) return @@ -238,7 +275,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator") + account, err := manager.AddAccount("test_account", "account_creator", "") if err != nil { t.Fatal(err) } @@ -271,7 +308,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - account, err = manager.GetAccount(account.Id) + account, err = manager.GetAccountById(account.Id) if err != nil { t.Fatal(err) return diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 4d9b86cf3..0b83799c1 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -32,7 +32,7 @@ func TestNewStore(t *testing.T) { func TestSaveAccount(t *testing.T) { store := newStore(t) - account := NewAccount("testuser") + account := NewAccount("testuser", "") account.Users["testuser"] = NewAdminUser("testuser") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey @@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) { func TestStore(t *testing.T) { store := newStore(t) - account := NewAccount("testuser") + account := NewAccount("testuser", "") account.Users["testuser"] = NewAdminUser("testuser") account.Peers["testpeer"] = &Peer{ Key: "peerkey", diff --git a/management/server/http/handler/peers.go b/management/server/http/handler/peers.go index 3e3fe564d..c39dc64e3 100644 --- a/management/server/http/handler/peers.go +++ b/management/server/http/handler/peers.go @@ -63,12 +63,21 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW writeJSONObject(w, "") } -func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { - userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) - //new user -> create a new account - account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId) +func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) { + jwtClaims := extractClaimsFromRequestContext(r, h.authAudience) + + account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { - log.Errorf("failed getting account of a user %s: %v", userId, err) + return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + } + + return account, nil +} + +func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { + account, err := h.getPeerAccount(r) + if err != nil { + log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } @@ -105,11 +114,9 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) - //new user -> create a new account - account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId) + account, err := h.getPeerAccount(r) if err != nil { - log.Errorf("failed getting account of a user %s: %v", userId, err) + log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } diff --git a/management/server/http/handler/setupkeys.go b/management/server/http/handler/setupkeys.go index 479410eba..23d6fdd76 100644 --- a/management/server/http/handler/setupkeys.go +++ b/management/server/http/handler/setupkeys.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "fmt" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/management/server" @@ -78,7 +79,7 @@ func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWri } func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) { - account, err := h.accountManager.GetAccount(accountId) + account, err := h.accountManager.GetAccountById(accountId) if err != nil { http.Error(w, "account doesn't exist", http.StatusInternalServerError) return @@ -119,11 +120,21 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R writeSuccess(w, setupKey) } -func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) { - userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) - account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId) +func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) { + jwtClaims := extractClaimsFromRequestContext(r, h.authAudience) + + account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { - log.Errorf("failed getting account of a user %s: %v", userId, err) + return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + } + + return account, nil +} + +func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) { + account, err := h.getSetupKeyAccount(r) + if err != nil { + log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } @@ -149,10 +160,9 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) { - userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) - account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId) + account, err := h.getSetupKeyAccount(r) if err != nil { - log.Errorf("failed getting account of a user %s: %v", userId, err) + log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } diff --git a/management/server/http/handler/util.go b/management/server/http/handler/util.go index 86b1993e1..248581e0e 100644 --- a/management/server/http/handler/util.go +++ b/management/server/http/handler/util.go @@ -8,17 +8,28 @@ import ( "time" ) -// extractUserAndAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth) -func extractUserAndAccountIdFromRequestContext(r *http.Request, authAudiance string) (userId, accountId string) { +// JWTClaims stores information from JWTs +type JWTClaims struct { + UserId string + AccountId string + Domain string +} + +// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) +func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims { token := r.Context().Value("user").(*jwt.Token) claims := token.Claims.(jwt.MapClaims) - - userId = claims["sub"].(string) - accountIdInt, ok := claims[authAudiance+"wt_account_id"] + jwtClaims := JWTClaims{} + jwtClaims.UserId = claims["sub"].(string) + accountIdClaim, ok := claims[authAudiance+"wt_account_id"] if ok { - accountId = accountIdInt.(string) + jwtClaims.AccountId = accountIdClaim.(string) } - return userId, accountId + domainClaim, ok := claims[authAudiance+"wt_user_domain"] + if ok { + jwtClaims.AccountId = domainClaim.(string) + } + return jwtClaims } //writeJSONObject simply writes object to the HTTP reponse in JSON format diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e70b691b2..ff7c4deb9 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -14,7 +14,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { expectedId := "test_account" userId := "account_creator" - account, err := manager.AddAccount(expectedId, userId) + account, err := manager.AddAccount(expectedId, userId, "") if err != nil { t.Fatal(err) } diff --git a/management/server/user.go b/management/server/user.go index 2134a944f..f16d36d15 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -40,14 +40,14 @@ func NewAdminUser(id string) *User { } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, error) { +func (am *AccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) { am.mux.Lock() defer am.mux.Unlock() account, err := am.Store.GetUserAccount(userId) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - account = NewAccount(userId) + account = NewAccount(userId, domain) account.Users[userId] = NewAdminUser(userId) err = am.Store.SaveAccount(account) if err != nil { @@ -59,6 +59,14 @@ func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, err } } + if account.Domain != domain { + account.Domain = domain + err = am.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed updating account with domain") + } + } + return account, nil }