Store domain information (#217)

* extract claim information from JWT

* get account function

* Store domain

* tests missing domain

* update existing account with domain

* add store domain tests
This commit is contained in:
Maycon Santos 2022-02-11 17:18:18 +01:00 committed by GitHub
parent 919f0aa3da
commit cd9a418df2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 136 additions and 54 deletions

View File

@ -23,6 +23,7 @@ type Account struct {
Id string
// User.Id it was created by
CreatedBy string
Domain string
SetupKeys map[string]*SetupKey
Network *Network
Peers map[string]*Peer
@ -30,9 +31,9 @@ type Account struct {
}
// NewAccount creates a new Account with a generated ID and generated default setup keys
func NewAccount(userId string) *Account {
func NewAccount(userId, domain string) *Account {
accountId := xid.New().String()
return newAccountWithId(accountId, userId)
return newAccountWithId(accountId, userId, domain)
}
func (a *Account) Copy() *Account {
@ -149,8 +150,8 @@ func (am *AccountManager) RenameSetupKey(accountId string, keyId string, newName
return keyCopy, nil
}
//GetAccount returns an existing account or error (NotFound) if doesn't exist
func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *AccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -164,12 +165,12 @@ func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId string) (*Account, error) {
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
if accountId != "" {
return am.GetAccount(accountId)
return am.GetAccountById(accountId)
} else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId)
account, err := am.GetOrCreateAccountByUser(userId, domain)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
}
@ -207,17 +208,17 @@ func (am *AccountManager) AccountExists(accountId string) (*bool, error) {
}
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store
func (am *AccountManager) AddAccount(accountId string, userId string) (*Account, error) {
func (am *AccountManager) AddAccount(accountId, userId, domain string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.createAccount(accountId, userId)
return am.createAccount(accountId, userId, domain)
}
func (am *AccountManager) createAccount(accountId string, userId string) (*Account, error) {
account := newAccountWithId(accountId, userId)
func (am *AccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
account := newAccountWithId(accountId, userId, domain)
err := am.Store.SaveAccount(account)
if err != nil {
@ -228,7 +229,7 @@ func (am *AccountManager) createAccount(accountId string, userId string) (*Accou
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId string, userId string) *Account {
func newAccountWithId(accountId, userId, domain string) *Account {
log.Debugf("creating new account")
@ -243,7 +244,15 @@ func newAccountWithId(accountId string, userId string) *Account {
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
return &Account{Id: accountId, SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userId}
return &Account{
Id: accountId,
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userId,
Domain: domain,
}
}
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {

View File

@ -14,7 +14,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(userId)
account, err := manager.GetOrCreateAccountByUser(userId, "")
if err != nil {
t.Fatal(err)
}
@ -32,6 +32,43 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
}
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatal(err)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
}
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
if account == nil {
t.Fatalf("expected to get an account for a user %s", userId)
}
if account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
}
func TestAccountManager_AddAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@ -48,7 +85,7 @@ func TestAccountManager_AddAccount(t *testing.T) {
Mask: net.IPMask{255, 192, 0, 0},
}
account, err := manager.AddAccount(expectedId, userId)
account, err := manager.AddAccount(expectedId, userId, "")
if err != nil {
t.Fatal(err)
}
@ -79,7 +116,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "")
account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
if err != nil {
t.Fatal(err)
}
@ -89,12 +126,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId)
_, err = manager.GetAccountByUserOrAccountId("", accountId, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
}
_, err = manager.GetAccountByUserOrAccountId("", "")
_, err = manager.GetAccountByUserOrAccountId("", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@ -109,7 +146,7 @@ func TestAccountManager_AccountExists(t *testing.T) {
expectedId := "test_account"
userId := "account_creator"
_, err = manager.AddAccount(expectedId, userId)
_, err = manager.AddAccount(expectedId, userId, "")
if err != nil {
t.Fatal(err)
}
@ -134,13 +171,13 @@ func TestAccountManager_GetAccount(t *testing.T) {
expectedId := "test_account"
userId := "account_creator"
account, err := manager.AddAccount(expectedId, userId)
account, err := manager.AddAccount(expectedId, userId, "")
if err != nil {
t.Fatal(err)
}
//AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccount(expectedId)
getAccount, err := manager.GetAccountById(expectedId)
if err != nil {
t.Fatal(err)
return
@ -171,7 +208,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err := manager.AddAccount("test_account", "account_creator")
account, err := manager.AddAccount("test_account", "account_creator", "")
if err != nil {
t.Fatal(err)
}
@ -211,7 +248,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.GetAccount(account.Id)
account, err = manager.GetAccountById(account.Id)
if err != nil {
t.Fatal(err)
return
@ -238,7 +275,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
account, err := manager.AddAccount("test_account", "account_creator")
account, err := manager.AddAccount("test_account", "account_creator", "")
if err != nil {
t.Fatal(err)
}
@ -271,7 +308,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
account, err = manager.GetAccount(account.Id)
account, err = manager.GetAccountById(account.Id)
if err != nil {
t.Fatal(err)
return

View File

@ -32,7 +32,7 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) {
store := newStore(t)
account := NewAccount("testuser")
account := NewAccount("testuser", "")
account.Users["testuser"] = NewAdminUser("testuser")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
func TestStore(t *testing.T) {
store := newStore(t)
account := NewAccount("testuser")
account := NewAccount("testuser", "")
account.Users["testuser"] = NewAdminUser("testuser")
account.Peers["testpeer"] = &Peer{
Key: "peerkey",

View File

@ -63,12 +63,21 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
writeJSONObject(w, "")
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
//new user -> create a new account
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err)
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := h.getPeerAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@ -105,11 +114,9 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
//new user -> create a new account
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
account, err := h.getPeerAccount(r)
if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}

View File

@ -2,6 +2,7 @@ package handler
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/management/server"
@ -78,7 +79,7 @@ func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWri
}
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
account, err := h.accountManager.GetAccount(accountId)
account, err := h.accountManager.GetAccountById(accountId)
if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return
@ -119,11 +120,21 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
writeSuccess(w, setupKey)
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err)
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
account, err := h.getSetupKeyAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@ -149,10 +160,9 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
account, err := h.getSetupKeyAccount(r)
if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}

View File

@ -8,17 +8,28 @@ import (
"time"
)
// extractUserAndAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth)
func extractUserAndAccountIdFromRequestContext(r *http.Request, authAudiance string) (userId, accountId string) {
// JWTClaims stores information from JWTs
type JWTClaims struct {
UserId string
AccountId string
Domain string
}
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
token := r.Context().Value("user").(*jwt.Token)
claims := token.Claims.(jwt.MapClaims)
userId = claims["sub"].(string)
accountIdInt, ok := claims[authAudiance+"wt_account_id"]
jwtClaims := JWTClaims{}
jwtClaims.UserId = claims["sub"].(string)
accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
if ok {
accountId = accountIdInt.(string)
jwtClaims.AccountId = accountIdClaim.(string)
}
return userId, accountId
domainClaim, ok := claims[authAudiance+"wt_user_domain"]
if ok {
jwtClaims.AccountId = domainClaim.(string)
}
return jwtClaims
}
//writeJSONObject simply writes object to the HTTP reponse in JSON format

View File

@ -14,7 +14,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
expectedId := "test_account"
userId := "account_creator"
account, err := manager.AddAccount(expectedId, userId)
account, err := manager.AddAccount(expectedId, userId, "")
if err != nil {
t.Fatal(err)
}

View File

@ -40,14 +40,14 @@ func NewAdminUser(id string) *User {
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, error) {
func (am *AccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetUserAccount(userId)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account = NewAccount(userId)
account = NewAccount(userId, domain)
account.Users[userId] = NewAdminUser(userId)
err = am.Store.SaveAccount(account)
if err != nil {
@ -59,6 +59,14 @@ func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, err
}
}
if account.Domain != domain {
account.Domain = domain
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
}
}
return account, nil
}