Store domain information (#217)

* extract claim information from JWT

* get account function

* Store domain

* tests missing domain

* update existing account with domain

* add store domain tests
This commit is contained in:
Maycon Santos 2022-02-11 17:18:18 +01:00 committed by GitHub
parent 919f0aa3da
commit cd9a418df2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 136 additions and 54 deletions

View File

@ -23,6 +23,7 @@ type Account struct {
Id string Id string
// User.Id it was created by // User.Id it was created by
CreatedBy string CreatedBy string
Domain string
SetupKeys map[string]*SetupKey SetupKeys map[string]*SetupKey
Network *Network Network *Network
Peers map[string]*Peer Peers map[string]*Peer
@ -30,9 +31,9 @@ type Account struct {
} }
// NewAccount creates a new Account with a generated ID and generated default setup keys // NewAccount creates a new Account with a generated ID and generated default setup keys
func NewAccount(userId string) *Account { func NewAccount(userId, domain string) *Account {
accountId := xid.New().String() accountId := xid.New().String()
return newAccountWithId(accountId, userId) return newAccountWithId(accountId, userId, domain)
} }
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
@ -149,8 +150,8 @@ func (am *AccountManager) RenameSetupKey(accountId string, keyId string, newName
return keyCopy, nil return keyCopy, nil
} }
//GetAccount returns an existing account or error (NotFound) if doesn't exist //GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *AccountManager) GetAccount(accountId string) (*Account, error) { func (am *AccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock() am.mux.Lock()
defer am.mux.Unlock() defer am.mux.Unlock()
@ -164,12 +165,12 @@ func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and //GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created // user id doesn't have an account associated with it, one account is created
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId string) (*Account, error) { func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
if accountId != "" { if accountId != "" {
return am.GetAccount(accountId) return am.GetAccountById(accountId)
} else if userId != "" { } else if userId != "" {
account, err := am.GetOrCreateAccountByUser(userId) account, err := am.GetOrCreateAccountByUser(userId, domain)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
} }
@ -207,17 +208,17 @@ func (am *AccountManager) AccountExists(accountId string) (*bool, error) {
} }
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store // AddAccount generates a new Account with a provided accountId and userId, saves to the Store
func (am *AccountManager) AddAccount(accountId string, userId string) (*Account, error) { func (am *AccountManager) AddAccount(accountId, userId, domain string) (*Account, error) {
am.mux.Lock() am.mux.Lock()
defer am.mux.Unlock() defer am.mux.Unlock()
return am.createAccount(accountId, userId) return am.createAccount(accountId, userId, domain)
} }
func (am *AccountManager) createAccount(accountId string, userId string) (*Account, error) { func (am *AccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
account := newAccountWithId(accountId, userId) account := newAccountWithId(accountId, userId, domain)
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(account)
if err != nil { if err != nil {
@ -228,7 +229,7 @@ func (am *AccountManager) createAccount(accountId string, userId string) (*Accou
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId string, userId string) *Account { func newAccountWithId(accountId, userId, domain string) *Account {
log.Debugf("creating new account") log.Debugf("creating new account")
@ -243,7 +244,15 @@ func newAccountWithId(accountId string, userId string) *Account {
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
return &Account{Id: accountId, SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userId} return &Account{
Id: accountId,
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userId,
Domain: domain,
}
} }
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey { func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {

View File

@ -14,7 +14,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
} }
userId := "test_user" userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(userId) account, err := manager.GetOrCreateAccountByUser(userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -32,6 +32,43 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
} }
} }
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatal(err)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
}
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
if account == nil {
t.Fatalf("expected to get an account for a user %s", userId)
}
if account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
}
func TestAccountManager_AddAccount(t *testing.T) { func TestAccountManager_AddAccount(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
@ -48,7 +85,7 @@ func TestAccountManager_AddAccount(t *testing.T) {
Mask: net.IPMask{255, 192, 0, 0}, Mask: net.IPMask{255, 192, 0, 0},
} }
account, err := manager.AddAccount(expectedId, userId) account, err := manager.AddAccount(expectedId, userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -79,7 +116,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountId(userId, "") account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -89,12 +126,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
accountId := account.Id accountId := account.Id
_, err = manager.GetAccountByUserOrAccountId("", accountId) _, err = manager.GetAccountByUserOrAccountId("", accountId, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
} }
_, err = manager.GetAccountByUserOrAccountId("", "") _, err = manager.GetAccountByUserOrAccountId("", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@ -109,7 +146,7 @@ func TestAccountManager_AccountExists(t *testing.T) {
expectedId := "test_account" expectedId := "test_account"
userId := "account_creator" userId := "account_creator"
_, err = manager.AddAccount(expectedId, userId) _, err = manager.AddAccount(expectedId, userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -134,13 +171,13 @@ func TestAccountManager_GetAccount(t *testing.T) {
expectedId := "test_account" expectedId := "test_account"
userId := "account_creator" userId := "account_creator"
account, err := manager.AddAccount(expectedId, userId) account, err := manager.AddAccount(expectedId, userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
//AddAccount has been already tested so we can assume it is correct and compare results //AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccount(expectedId) getAccount, err := manager.GetAccountById(expectedId)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -171,7 +208,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err := manager.AddAccount("test_account", "account_creator") account, err := manager.AddAccount("test_account", "account_creator", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -211,7 +248,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err = manager.GetAccount(account.Id) account, err = manager.GetAccountById(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -238,7 +275,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
account, err := manager.AddAccount("test_account", "account_creator") account, err := manager.AddAccount("test_account", "account_creator", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -271,7 +308,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
account, err = manager.GetAccount(account.Id) account, err = manager.GetAccountById(account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@ -32,7 +32,7 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) { func TestSaveAccount(t *testing.T) {
store := newStore(t) store := newStore(t)
account := NewAccount("testuser") account := NewAccount("testuser", "")
account.Users["testuser"] = NewAdminUser("testuser") account.Users["testuser"] = NewAdminUser("testuser")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
store := newStore(t) store := newStore(t)
account := NewAccount("testuser") account := NewAccount("testuser", "")
account.Users["testuser"] = NewAdminUser("testuser") account.Users["testuser"] = NewAdminUser("testuser")
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &Peer{
Key: "peerkey", Key: "peerkey",

View File

@ -63,12 +63,21 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
writeJSONObject(w, "") writeJSONObject(w, "")
} }
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
//new user -> create a new account
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId) account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil { if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err) return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := h.getPeerAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@ -105,11 +114,9 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) account, err := h.getPeerAccount(r)
//new user -> create a new account
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
if err != nil { if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }

View File

@ -2,6 +2,7 @@ package handler
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/management/server" "github.com/wiretrustee/wiretrustee/management/server"
@ -78,7 +79,7 @@ func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWri
} }
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
account, err := h.accountManager.GetAccount(accountId) account, err := h.accountManager.GetAccountById(accountId)
if err != nil { if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError) http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return return
@ -119,11 +120,21 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
writeSuccess(w, setupKey) writeSuccess(w, setupKey)
} }
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil { if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err) return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
account, err := h.getSetupKeyAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@ -149,10 +160,9 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience) account, err := h.getSetupKeyAccount(r)
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
if err != nil { if err != nil {
log.Errorf("failed getting account of a user %s: %v", userId, err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }

View File

@ -8,17 +8,28 @@ import (
"time" "time"
) )
// extractUserAndAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth) // JWTClaims stores information from JWTs
func extractUserAndAccountIdFromRequestContext(r *http.Request, authAudiance string) (userId, accountId string) { type JWTClaims struct {
UserId string
AccountId string
Domain string
}
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
token := r.Context().Value("user").(*jwt.Token) token := r.Context().Value("user").(*jwt.Token)
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
jwtClaims := JWTClaims{}
userId = claims["sub"].(string) jwtClaims.UserId = claims["sub"].(string)
accountIdInt, ok := claims[authAudiance+"wt_account_id"] accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
if ok { if ok {
accountId = accountIdInt.(string) jwtClaims.AccountId = accountIdClaim.(string)
} }
return userId, accountId domainClaim, ok := claims[authAudiance+"wt_user_domain"]
if ok {
jwtClaims.AccountId = domainClaim.(string)
}
return jwtClaims
} }
//writeJSONObject simply writes object to the HTTP reponse in JSON format //writeJSONObject simply writes object to the HTTP reponse in JSON format

View File

@ -14,7 +14,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
expectedId := "test_account" expectedId := "test_account"
userId := "account_creator" userId := "account_creator"
account, err := manager.AddAccount(expectedId, userId) account, err := manager.AddAccount(expectedId, userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -40,14 +40,14 @@ func NewAdminUser(id string) *User {
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, error) { func (am *AccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
am.mux.Lock() am.mux.Lock()
defer am.mux.Unlock() defer am.mux.Unlock()
account, err := am.Store.GetUserAccount(userId) account, err := am.Store.GetUserAccount(userId)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account = NewAccount(userId) account = NewAccount(userId, domain)
account.Users[userId] = NewAdminUser(userId) account.Users[userId] = NewAdminUser(userId)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
@ -59,6 +59,14 @@ func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, err
} }
} }
if account.Domain != domain {
account.Domain = domain
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
}
}
return account, nil return account, nil
} }