mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 08:03:30 +01:00
Store domain information (#217)
* extract claim information from JWT * get account function * Store domain * tests missing domain * update existing account with domain * add store domain tests
This commit is contained in:
parent
919f0aa3da
commit
cd9a418df2
@ -23,6 +23,7 @@ type Account struct {
|
||||
Id string
|
||||
// User.Id it was created by
|
||||
CreatedBy string
|
||||
Domain string
|
||||
SetupKeys map[string]*SetupKey
|
||||
Network *Network
|
||||
Peers map[string]*Peer
|
||||
@ -30,9 +31,9 @@ type Account struct {
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
||||
func NewAccount(userId string) *Account {
|
||||
func NewAccount(userId, domain string) *Account {
|
||||
accountId := xid.New().String()
|
||||
return newAccountWithId(accountId, userId)
|
||||
return newAccountWithId(accountId, userId, domain)
|
||||
}
|
||||
|
||||
func (a *Account) Copy() *Account {
|
||||
@ -149,8 +150,8 @@ func (am *AccountManager) RenameSetupKey(accountId string, keyId string, newName
|
||||
return keyCopy, nil
|
||||
}
|
||||
|
||||
//GetAccount returns an existing account or error (NotFound) if doesn't exist
|
||||
func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
|
||||
//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
|
||||
func (am *AccountManager) GetAccountById(accountId string) (*Account, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
@ -164,12 +165,12 @@ func (am *AccountManager) GetAccount(accountId string) (*Account, error) {
|
||||
|
||||
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
||||
// user id doesn't have an account associated with it, one account is created
|
||||
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId string) (*Account, error) {
|
||||
func (am *AccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
|
||||
|
||||
if accountId != "" {
|
||||
return am.GetAccount(accountId)
|
||||
return am.GetAccountById(accountId)
|
||||
} else if userId != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(userId)
|
||||
account, err := am.GetOrCreateAccountByUser(userId, domain)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
|
||||
}
|
||||
@ -207,17 +208,17 @@ func (am *AccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
}
|
||||
|
||||
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store
|
||||
func (am *AccountManager) AddAccount(accountId string, userId string) (*Account, error) {
|
||||
func (am *AccountManager) AddAccount(accountId, userId, domain string) (*Account, error) {
|
||||
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
return am.createAccount(accountId, userId)
|
||||
return am.createAccount(accountId, userId, domain)
|
||||
|
||||
}
|
||||
|
||||
func (am *AccountManager) createAccount(accountId string, userId string) (*Account, error) {
|
||||
account := newAccountWithId(accountId, userId)
|
||||
func (am *AccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
|
||||
account := newAccountWithId(accountId, userId, domain)
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@ -228,7 +229,7 @@ func (am *AccountManager) createAccount(accountId string, userId string) (*Accou
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(accountId string, userId string) *Account {
|
||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
|
||||
log.Debugf("creating new account")
|
||||
|
||||
@ -243,7 +244,15 @@ func newAccountWithId(accountId string, userId string) *Account {
|
||||
|
||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
||||
|
||||
return &Account{Id: accountId, SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userId}
|
||||
return &Account{
|
||||
Id: accountId,
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userId,
|
||||
Domain: domain,
|
||||
}
|
||||
}
|
||||
|
||||
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
|
||||
|
@ -14,7 +14,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId)
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -32,6 +32,43 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||
}
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
t.Fatalf("expected to get an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_AddAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@ -48,7 +85,7 @@ func TestAccountManager_AddAccount(t *testing.T) {
|
||||
Mask: net.IPMask{255, 192, 0, 0},
|
||||
}
|
||||
|
||||
account, err := manager.AddAccount(expectedId, userId)
|
||||
account, err := manager.AddAccount(expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -79,7 +116,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountId(userId, "")
|
||||
account, err := manager.GetAccountByUserOrAccountId(userId, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -89,12 +126,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
accountId := account.Id
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountId("", accountId)
|
||||
_, err = manager.GetAccountByUserOrAccountId("", accountId, "")
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId)
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountId("", "")
|
||||
_, err = manager.GetAccountByUserOrAccountId("", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user and account IDs are empty")
|
||||
}
|
||||
@ -109,7 +146,7 @@ func TestAccountManager_AccountExists(t *testing.T) {
|
||||
|
||||
expectedId := "test_account"
|
||||
userId := "account_creator"
|
||||
_, err = manager.AddAccount(expectedId, userId)
|
||||
_, err = manager.AddAccount(expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -134,13 +171,13 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
|
||||
expectedId := "test_account"
|
||||
userId := "account_creator"
|
||||
account, err := manager.AddAccount(expectedId, userId)
|
||||
account, err := manager.AddAccount(expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//AddAccount has been already tested so we can assume it is correct and compare results
|
||||
getAccount, err := manager.GetAccount(expectedId)
|
||||
getAccount, err := manager.GetAccountById(expectedId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -171,7 +208,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.AddAccount("test_account", "account_creator")
|
||||
account, err := manager.AddAccount("test_account", "account_creator", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -211,7 +248,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.GetAccount(account.Id)
|
||||
account, err = manager.GetAccountById(account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -238,7 +275,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.AddAccount("test_account", "account_creator")
|
||||
account, err := manager.AddAccount("test_account", "account_creator", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -271,7 +308,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.GetAccount(account.Id)
|
||||
account, err = manager.GetAccountById(account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
@ -32,7 +32,7 @@ func TestNewStore(t *testing.T) {
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
account := NewAccount("testuser")
|
||||
account := NewAccount("testuser", "")
|
||||
account.Users["testuser"] = NewAdminUser("testuser")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
account := NewAccount("testuser")
|
||||
account := NewAccount("testuser", "")
|
||||
account.Users["testuser"] = NewAdminUser("testuser")
|
||||
account.Peers["testpeer"] = &Peer{
|
||||
Key: "peerkey",
|
||||
|
@ -63,12 +63,21 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
|
||||
writeJSONObject(w, "")
|
||||
}
|
||||
|
||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
||||
//new user -> create a new account
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
||||
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getPeerAccount(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -105,11 +114,9 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
||||
//new user -> create a new account
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
||||
account, err := h.getPeerAccount(r)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/management/server"
|
||||
@ -78,7 +79,7 @@ func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWri
|
||||
}
|
||||
|
||||
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.accountManager.GetAccount(accountId)
|
||||
account, err := h.accountManager.GetAccountById(accountId)
|
||||
if err != nil {
|
||||
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
|
||||
return
|
||||
@ -119,11 +120,21 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
||||
writeSuccess(w, setupKey)
|
||||
}
|
||||
|
||||
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
||||
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getSetupKeyAccount(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -149,10 +160,9 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userId, accountId := extractUserAndAccountIdFromRequestContext(r, h.authAudience)
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(userId, accountId)
|
||||
account, err := h.getSetupKeyAccount(r)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account of a user %s: %v", userId, err)
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -8,17 +8,28 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// extractUserAndAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth)
|
||||
func extractUserAndAccountIdFromRequestContext(r *http.Request, authAudiance string) (userId, accountId string) {
|
||||
// JWTClaims stores information from JWTs
|
||||
type JWTClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
}
|
||||
|
||||
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
|
||||
token := r.Context().Value("user").(*jwt.Token)
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
|
||||
userId = claims["sub"].(string)
|
||||
accountIdInt, ok := claims[authAudiance+"wt_account_id"]
|
||||
jwtClaims := JWTClaims{}
|
||||
jwtClaims.UserId = claims["sub"].(string)
|
||||
accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
|
||||
if ok {
|
||||
accountId = accountIdInt.(string)
|
||||
jwtClaims.AccountId = accountIdClaim.(string)
|
||||
}
|
||||
return userId, accountId
|
||||
domainClaim, ok := claims[authAudiance+"wt_user_domain"]
|
||||
if ok {
|
||||
jwtClaims.AccountId = domainClaim.(string)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
|
@ -14,7 +14,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
|
||||
expectedId := "test_account"
|
||||
userId := "account_creator"
|
||||
account, err := manager.AddAccount(expectedId, userId)
|
||||
account, err := manager.AddAccount(expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -40,14 +40,14 @@ func NewAdminUser(id string) *User {
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, error) {
|
||||
func (am *AccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetUserAccount(userId)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
account = NewAccount(userId)
|
||||
account = NewAccount(userId, domain)
|
||||
account.Users[userId] = NewAdminUser(userId)
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@ -59,6 +59,14 @@ func (am *AccountManager) GetOrCreateAccountByUser(userId string) (*Account, err
|
||||
}
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
account.Domain = domain
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user