mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Group users of same private domain (#243)
* Added Domain Category field and fix store tests * Add GetAccountByDomain method * Add Domain Category to authorization claims * Initial GetAccountWithAuthorizationClaims test cases * Renamed Private Domain map and index it on saving account * New Go build tags * Added NewRegularUser function * Updated restore to account for primary domain account Also, added another test case * Added grouping user of private domains Also added auxiliary methods for update metadata and domain attributes * Update http handles get account method and tests * Fix lint and document another case * Removed unnecessary log * Move use cases to method and add flow comments * Split the new user and existing logic from GetAccountWithAuthorizationClaims * Review: minor corrections Co-authored-by: braginini <bangvalo@gmail.com>
This commit is contained in:
parent
5d4c2643a3
commit
0b8387bd2c
@ -1,3 +1,4 @@
|
||||
//go:build linux || windows
|
||||
// +build linux windows
|
||||
|
||||
package iface
|
||||
|
@ -1,3 +1,4 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package iface
|
||||
|
@ -4,12 +4,20 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/idp"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
)
|
||||
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||
GetAccountByUser(userId string) (*Account, error)
|
||||
@ -18,6 +26,7 @@ type AccountManager interface {
|
||||
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
||||
GetAccountById(accountId string) (*Account, error)
|
||||
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
|
||||
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
AddAccount(accountId, userId, domain string) (*Account, error)
|
||||
GetPeer(peerKey string) (*Peer, error)
|
||||
@ -41,12 +50,14 @@ type DefaultAccountManager struct {
|
||||
type Account struct {
|
||||
Id string
|
||||
// User.Id it was created by
|
||||
CreatedBy string
|
||||
Domain string
|
||||
SetupKeys map[string]*SetupKey
|
||||
Network *Network
|
||||
Peers map[string]*Peer
|
||||
Users map[string]*User
|
||||
CreatedBy string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
IsDomainPrimaryAccount bool
|
||||
SetupKeys map[string]*SetupKey
|
||||
Network *Network
|
||||
Peers map[string]*Peer
|
||||
Users map[string]*User
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
||||
@ -193,12 +204,9 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId,
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
|
||||
}
|
||||
// update idp manager app metadata
|
||||
if am.idpManager != nil {
|
||||
err = am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: account.Id})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
err = am.updateIDPMetadata(userId, account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
@ -206,6 +214,137 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId,
|
||||
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided")
|
||||
}
|
||||
|
||||
// updateIDPMetadata update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error {
|
||||
if am.idpManager != nil {
|
||||
err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool) error {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
account.Domain = strings.ToLower(claims.Domain)
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed saving updated account")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
//
|
||||
//
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account, domainAcc *Account, claims jwtclaims.AuthorizationClaims) error {
|
||||
var err error
|
||||
|
||||
if domainAcc == nil || existingAcc.Id != domainAcc.Id {
|
||||
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.updateIDPMetadata(claims.UserId, existingAcc.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
primaryAccount bool
|
||||
)
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
// if domain already has a primary account, add regular user
|
||||
if domainAcc != nil {
|
||||
account = domainAcc
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
primaryAccount = false
|
||||
} else {
|
||||
account = NewAccount(claims.UserId, lowerDomain)
|
||||
account.Users[claims.UserId] = NewAdminUser(claims.UserId)
|
||||
primaryAccount = true
|
||||
}
|
||||
|
||||
err := am.updateAccountDomainAttributes(account, claims, primaryAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateIDPMetadata(claims.UserId, account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
// Use cases:
|
||||
//
|
||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
||||
//
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
||||
//
|
||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
||||
//
|
||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||
//
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || claims.AccountId != "" {
|
||||
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
|
||||
}
|
||||
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||
accStatus, _ := status.FromError(err)
|
||||
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetUserAccount(claims.UserId)
|
||||
if err == nil {
|
||||
err = am.handleExistingUserAccount(account, domainAccount, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return am.handleNewUserAccount(domainAccount, claims)
|
||||
} else {
|
||||
// other error
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
//AccountExists checks whether account exists (returns true) or not (returns false)
|
||||
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
am.mux.Lock()
|
||||
|
@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"net"
|
||||
"testing"
|
||||
@ -32,6 +34,154 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
inputClaims jwtclaims.AuthorizationClaims
|
||||
inputInitUserParams initUserParams
|
||||
inputUpdateAttrs bool
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedMSG string
|
||||
expectedUserRole UserRole
|
||||
}
|
||||
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
privateDomain = "private.com"
|
||||
unknownDomain = "unknown.com"
|
||||
)
|
||||
|
||||
defaultInitAccount := initUserParams{
|
||||
Domain: publicDomain,
|
||||
UserId: "defaultUser",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleAdmin,
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
testCase2 := test{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleAdmin,
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleAdmin,
|
||||
}
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
|
||||
testCase4 := test{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleAdmin,
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
|
||||
testCase.testingFunc(t, initAccount.Id, account.Id, testCase.expectedMSG)
|
||||
|
||||
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "user role should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUser(userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
|
||||
if account != nil && account.Users[userId] == nil {
|
||||
t.Fatalf("expected to create an account for a user %s but no user was found after creation udner the account %s", userId, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
|
@ -17,10 +17,11 @@ const storeFileName = "store.json"
|
||||
|
||||
// FileStore represents an account storage backed by a file persisted to disk
|
||||
type FileStore struct {
|
||||
Accounts map[string]*Account
|
||||
SetupKeyId2AccountId map[string]string `json:"-"`
|
||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
||||
UserId2AccountId map[string]string `json:"-"`
|
||||
Accounts map[string]*Account
|
||||
SetupKeyId2AccountId map[string]string `json:"-"`
|
||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
||||
UserId2AccountId map[string]string `json:"-"`
|
||||
PrivateDomain2AccountId map[string]string `json:"-"`
|
||||
|
||||
// mutex to synchronise Store read/write operations
|
||||
mux sync.Mutex `json:"-"`
|
||||
@ -42,12 +43,13 @@ func restore(file string) (*FileStore, error) {
|
||||
if _, err := os.Stat(file); os.IsNotExist(err) {
|
||||
// create a new FileStore if previously didn't exist (e.g. first run)
|
||||
s := &FileStore{
|
||||
Accounts: make(map[string]*Account),
|
||||
mux: sync.Mutex{},
|
||||
SetupKeyId2AccountId: make(map[string]string),
|
||||
PeerKeyId2AccountId: make(map[string]string),
|
||||
UserId2AccountId: make(map[string]string),
|
||||
storeFile: file,
|
||||
Accounts: make(map[string]*Account),
|
||||
mux: sync.Mutex{},
|
||||
SetupKeyId2AccountId: make(map[string]string),
|
||||
PeerKeyId2AccountId: make(map[string]string),
|
||||
UserId2AccountId: make(map[string]string),
|
||||
PrivateDomain2AccountId: make(map[string]string),
|
||||
storeFile: file,
|
||||
}
|
||||
|
||||
err = s.persist(file)
|
||||
@ -68,6 +70,7 @@ func restore(file string) (*FileStore, error) {
|
||||
store.SetupKeyId2AccountId = make(map[string]string)
|
||||
store.PeerKeyId2AccountId = make(map[string]string)
|
||||
store.UserId2AccountId = make(map[string]string)
|
||||
store.PrivateDomain2AccountId = make(map[string]string)
|
||||
for accountId, account := range store.Accounts {
|
||||
for setupKeyId := range account.SetupKeys {
|
||||
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
|
||||
@ -78,6 +81,12 @@ func restore(file string) (*FileStore, error) {
|
||||
for _, user := range account.Users {
|
||||
store.UserId2AccountId[user.Id] = accountId
|
||||
}
|
||||
for _, user := range account.Users {
|
||||
store.UserId2AccountId[user.Id] = accountId
|
||||
}
|
||||
if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
||||
store.PrivateDomain2AccountId[account.Domain] = accountId
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
@ -178,6 +187,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.UserId2AccountId[user.Id] = account.Id
|
||||
}
|
||||
|
||||
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
||||
s.PrivateDomain2AccountId[account.Domain] = account.Id
|
||||
}
|
||||
|
||||
err := s.persist(s.storeFile)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -186,6 +199,21 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
|
||||
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
|
||||
if !accountIdFound {
|
||||
return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private")
|
||||
}
|
||||
|
||||
account, err := s.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
|
||||
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"net"
|
||||
"path/filepath"
|
||||
@ -131,34 +132,45 @@ func TestRestore(t *testing.T) {
|
||||
}
|
||||
|
||||
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
if account == nil {
|
||||
t.Errorf("failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
|
||||
require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
|
||||
require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003")
|
||||
|
||||
require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003")
|
||||
|
||||
require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network")
|
||||
|
||||
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
|
||||
require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length")
|
||||
|
||||
require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length")
|
||||
|
||||
require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length")
|
||||
}
|
||||
|
||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if account != nil && account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"] == nil {
|
||||
t.Errorf("failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003")
|
||||
store, err := NewStore(storeDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if account != nil && account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"] == nil {
|
||||
t.Errorf("failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003")
|
||||
}
|
||||
existingDomain := "test.com"
|
||||
|
||||
if account != nil && account.Network == nil {
|
||||
t.Errorf("failed to restore a FileStore file - missing Account Network")
|
||||
}
|
||||
|
||||
if account != nil && account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"] == nil {
|
||||
t.Errorf("failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
}
|
||||
|
||||
if len(store.UserId2AccountId) != 2 {
|
||||
t.Errorf("failed to restore a FileStore wrong UserId2AccountId mapping")
|
||||
}
|
||||
|
||||
if len(store.SetupKeyId2AccountId) != 1 {
|
||||
t.Errorf("failed to restore a FileStore wrong SetupKeyId2AccountId mapping")
|
||||
}
|
||||
account, err := store.GetAccountByPrivateDomain(existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
}
|
||||
|
||||
func newStore(t *testing.T) *FileStore {
|
||||
|
@ -72,7 +72,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
|
||||
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
@ -17,9 +17,9 @@ import (
|
||||
func initTestMetaData(peer ...*server.Peer) *Peers {
|
||||
return &Peers{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountByUserOrAccountIdFunc: func(userId, accountId, domain string) (*server.Account, error) {
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: accountId,
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*server.Peer{
|
||||
"test_peer": peer[0],
|
||||
|
@ -126,7 +126,7 @@ func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error)
|
||||
extractor := jwtclaims.NewClaimsExtractor(nil)
|
||||
jwtClaims := extractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
@ -2,7 +2,8 @@ package jwtclaims
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
}
|
||||
|
@ -6,10 +6,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TokenUserProperty = "user"
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
DomainIDSuffix = "wt_account_domain"
|
||||
UserIDClaim = "sub"
|
||||
TokenUserProperty = "user"
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
DomainIDSuffix = "wt_account_domain"
|
||||
DomainCategorySuffix = "wt_account_domain_category"
|
||||
UserIDClaim = "sub"
|
||||
)
|
||||
|
||||
// Extract function type
|
||||
@ -47,5 +48,9 @@ func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) Autho
|
||||
if ok {
|
||||
jwtClaims.Domain = domainClaim.(string)
|
||||
}
|
||||
domainCategoryClaim, ok := claims[authAudiance+DomainCategorySuffix]
|
||||
if ok {
|
||||
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
@ -19,6 +19,9 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
|
||||
if claims.Domain != "" {
|
||||
claimMaps[audiance+DomainIDSuffix] = claims.Domain
|
||||
}
|
||||
if claims.DomainCategory != "" {
|
||||
claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
@ -41,9 +44,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
name: "All Claim Fields",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
DomainCategory: "public",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@ -72,6 +76,18 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Category Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Only User ID Is set",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
@ -81,7 +97,7 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||
|
@ -2,28 +2,30 @@ package mock_server
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/management/server"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error)
|
||||
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
|
||||
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
AddAccountFunc func(accountId, userId, domain string) (*server.Account, error)
|
||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error)
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error)
|
||||
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
|
||||
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
AddAccountFunc func(accountId, userId, domain string) (*server.Account, error)
|
||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) {
|
||||
@ -75,6 +77,13 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, dom
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
if am.GetAccountWithAuthorizationClaimsFunc != nil {
|
||||
return am.GetAccountWithAuthorizationClaimsFunc(claims)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountWithAuthorizationClaims not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
if am.AccountExistsFunc != nil {
|
||||
return am.AccountExistsFunc(accountId)
|
||||
|
@ -9,5 +9,6 @@ type Store interface {
|
||||
GetAccountPeers(accountId string) ([]*Peer, error)
|
||||
GetPeerAccount(peerKey string) (*Account, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error)
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
SaveAccount(account *Account) error
|
||||
}
|
||||
|
3
management/server/testdata/store.json
vendored
3
management/server/testdata/store.json
vendored
@ -2,6 +2,9 @@
|
||||
"Accounts": {
|
||||
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
"Domain": "test.com",
|
||||
"DomainCategory": "private",
|
||||
"IsDomainPrimaryAccount": true,
|
||||
"SetupKeys": {
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -34,6 +35,11 @@ func NewUser(id string, role UserRole) *User {
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleAdmin
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin)
|
||||
@ -44,10 +50,12 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
lowerDomain := strings.ToLower(domain)
|
||||
|
||||
account, err := am.Store.GetUserAccount(userId)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
account = NewAccount(userId, domain)
|
||||
account = NewAccount(userId, lowerDomain)
|
||||
account.Users[userId] = NewAdminUser(userId)
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@ -59,8 +67,8 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
||||
}
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
account.Domain = domain
|
||||
if account.Domain != lowerDomain {
|
||||
account.Domain = lowerDomain
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
|
||||
|
Loading…
Reference in New Issue
Block a user