From 0b8387bd2c4ce62e690e023176438a7dcdf5199f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 1 Mar 2022 15:22:18 +0100 Subject: [PATCH] Group users of same private domain (#243) * Added Domain Category field and fix store tests * Add GetAccountByDomain method * Add Domain Category to authorization claims * Initial GetAccountWithAuthorizationClaims test cases * Renamed Private Domain map and index it on saving account * New Go build tags * Added NewRegularUser function * Updated restore to account for primary domain account Also, added another test case * Added grouping user of private domains Also added auxiliary methods for update metadata and domain attributes * Update http handles get account method and tests * Fix lint and document another case * Removed unnecessary log * Move use cases to method and add flow comments * Split the new user and existing logic from GetAccountWithAuthorizationClaims * Review: minor corrections Co-authored-by: braginini --- iface/ifacename.go | 1 + iface/ifacename_darwin.go | 1 + management/server/account.go | 163 ++++++++++++++++-- management/server/account_test.go | 150 ++++++++++++++++ management/server/file_store.go | 48 ++++-- management/server/file_store_test.go | 56 +++--- management/server/http/handler/peers.go | 2 +- management/server/http/handler/peers_test.go | 4 +- management/server/http/handler/setupkeys.go | 2 +- management/server/jwtclaims/claims.go | 7 +- management/server/jwtclaims/extractor.go | 13 +- management/server/jwtclaims/extractor_test.go | 24 ++- management/server/mock_server/account_mock.go | 41 +++-- management/server/store.go | 1 + management/server/testdata/store.json | 3 + management/server/user.go | 14 +- 16 files changed, 452 insertions(+), 78 deletions(-) diff --git a/iface/ifacename.go b/iface/ifacename.go index 3fcb0e60e..05d0299d3 100644 --- a/iface/ifacename.go +++ b/iface/ifacename.go @@ -1,3 +1,4 @@ +//go:build linux || windows // +build linux windows package iface diff --git a/iface/ifacename_darwin.go b/iface/ifacename_darwin.go index 5f25ad29e..c80f790f5 100644 --- a/iface/ifacename_darwin.go +++ b/iface/ifacename_darwin.go @@ -1,3 +1,4 @@ +//go:build darwin // +build darwin package iface diff --git a/management/server/account.go b/management/server/account.go index c597a7c46..fc5c89eb5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,12 +4,20 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/management/server/idp" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "github.com/wiretrustee/wiretrustee/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "strings" "sync" ) +const ( + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" +) + type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetAccountByUser(userId string) (*Account, error) @@ -18,6 +26,7 @@ type AccountManager interface { RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) + GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) AccountExists(accountId string) (*bool, error) AddAccount(accountId, userId, domain string) (*Account, error) GetPeer(peerKey string) (*Peer, error) @@ -41,12 +50,14 @@ type DefaultAccountManager struct { type Account struct { Id string // User.Id it was created by - CreatedBy string - Domain string - SetupKeys map[string]*SetupKey - Network *Network - Peers map[string]*Peer - Users map[string]*User + CreatedBy string + Domain string + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey + Network *Network + Peers map[string]*Peer + Users map[string]*User } // NewAccount creates a new Account with a generated ID and generated default setup keys @@ -193,12 +204,9 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, if err != nil { return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) } - // update idp manager app metadata - if am.idpManager != nil { - err = am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: account.Id}) - if err != nil { - return nil, status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err) - } + err = am.updateIDPMetadata(userId, account.Id) + if err != nil { + return nil, err } return account, nil } @@ -206,6 +214,137 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided") } +// updateIDPMetadata update user's app metadata in idp manager +func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error { + if am.idpManager != nil { + err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID}) + if err != nil { + return status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err) + } + } + return nil +} + +// updateAccountDomainAttributes updates the account domain attributes and then, saves the account +func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool) error { + account.IsDomainPrimaryAccount = primaryDomain + account.Domain = strings.ToLower(claims.Domain) + account.DomainCategory = claims.DomainCategory + err := am.Store.SaveAccount(account) + if err != nil { + return status.Errorf(codes.Internal, "failed saving updated account") + } + return nil +} + +// handleExistingUserAccount handles existing User accounts and update its domain attributes. +// +// +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. +func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account, domainAcc *Account, claims jwtclaims.AuthorizationClaims) error { + var err error + + if domainAcc == nil || existingAcc.Id != domainAcc.Id { + err = am.updateAccountDomainAttributes(existingAcc, claims, false) + if err != nil { + return err + } + } + + // we should register the account ID to this user's metadata in our IDP manager + err = am.updateIDPMetadata(claims.UserId, existingAcc.Id) + if err != nil { + return err + } + + return nil +} + +// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// otherwise it will create a new account and make it primary account for the domain. +func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { + var ( + account *Account + primaryAccount bool + ) + lowerDomain := strings.ToLower(claims.Domain) + // if domain already has a primary account, add regular user + if domainAcc != nil { + account = domainAcc + account.Users[claims.UserId] = NewRegularUser(claims.UserId) + primaryAccount = false + } else { + account = NewAccount(claims.UserId, lowerDomain) + account.Users[claims.UserId] = NewAdminUser(claims.UserId) + primaryAccount = true + } + + err := am.updateAccountDomainAttributes(account, claims, primaryAccount) + if err != nil { + return nil, err + } + + err = am.updateIDPMetadata(claims.UserId, account.Id) + if err != nil { + return nil, err + } + + return account, nil +} + +// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// if domain is of the PrivateCategory category, it will evaluate +// if account is new, existing or if there is another account with the same domain +// +// Use cases: +// +// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// +// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// +// New user + New account + Existing Public Domain -> create account, user role = admin +// +// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) +// +// Existing user + Existing account + Existing Indexed Domain -> Nothing changes +// +// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) +func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { + // if Account ID is part of the claims + // it means that we've already classified the domain and user has an account + if claims.DomainCategory != PrivateCategory || claims.AccountId != "" { + return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) + } + + am.mux.Lock() + defer am.mux.Unlock() + + // We checked if the domain has a primary account already + domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) + accStatus, _ := status.FromError(err) + if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound { + return nil, err + } + + account, err := am.Store.GetUserAccount(claims.UserId) + if err == nil { + err = am.handleExistingUserAccount(account, domainAccount, claims) + if err != nil { + return nil, err + } + return account, nil + } else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + return am.handleNewUserAccount(domainAccount, claims) + } else { + // other error + return nil, err + } +} + //AccountExists checks whether account exists (returns true) or not (returns false) func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { am.mux.Lock() diff --git a/management/server/account_test.go b/management/server/account_test.go index d2f5b3535..0e6d96ead 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1,6 +1,8 @@ package server import ( + "github.com/stretchr/testify/require" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "net" "testing" @@ -32,6 +34,154 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } } +func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { + + type initUserParams jwtclaims.AuthorizationClaims + + type test struct { + name string + inputClaims jwtclaims.AuthorizationClaims + inputInitUserParams initUserParams + inputUpdateAttrs bool + testingFunc require.ComparisonAssertionFunc + expectedMSG string + expectedUserRole UserRole + } + + var ( + publicDomain = "public.com" + privateDomain = "private.com" + unknownDomain = "unknown.com" + ) + + defaultInitAccount := initUserParams{ + Domain: publicDomain, + UserId: "defaultUser", + } + + testCase1 := test{ + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleAdmin, + } + + initUnknown := defaultInitAccount + initUnknown.DomainCategory = UnknownCategory + initUnknown.Domain = unknownDomain + + testCase2 := test{ + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleAdmin, + } + + testCase3 := test{ + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleAdmin, + } + + privateInitAccount := defaultInitAccount + privateInitAccount.Domain = privateDomain + privateInitAccount.DomainCategory = PrivateCategory + + testCase4 := test{ + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + } + + testCase5 := test{ + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleAdmin, + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { + t.Run(testCase.name, func(t *testing.T) { + + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + require.NoError(t, err, "create init user failed") + + if testCase.inputUpdateAttrs { + err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + require.NoError(t, err, "update init user failed") + } + + account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims) + require.NoError(t, err, "support function failed") + + testCase.testingFunc(t, initAccount.Id, account.Id, testCase.expectedMSG) + + require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "user role should match") + }) + } +} +func TestAccountManager_PrivateAccount(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + userId := "test_user" + account, err := manager.GetOrCreateAccountByUser(userId, "") + if err != nil { + t.Fatal(err) + } + if account == nil { + t.Fatalf("expected to create an account for a user %s", userId) + } + + account, err = manager.GetAccountByUser(userId) + if err != nil { + t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) + } + + if account != nil && account.Users[userId] == nil { + t.Fatalf("expected to create an account for a user %s but no user was found after creation udner the account %s", userId, account.Id) + } +} + func TestAccountManager_SetOrUpdateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/file_store.go b/management/server/file_store.go index 5fc8b9e54..a6833b04f 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -17,10 +17,11 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { - Accounts map[string]*Account - SetupKeyId2AccountId map[string]string `json:"-"` - PeerKeyId2AccountId map[string]string `json:"-"` - UserId2AccountId map[string]string `json:"-"` + Accounts map[string]*Account + SetupKeyId2AccountId map[string]string `json:"-"` + PeerKeyId2AccountId map[string]string `json:"-"` + UserId2AccountId map[string]string `json:"-"` + PrivateDomain2AccountId map[string]string `json:"-"` // mutex to synchronise Store read/write operations mux sync.Mutex `json:"-"` @@ -42,12 +43,13 @@ func restore(file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ - Accounts: make(map[string]*Account), - mux: sync.Mutex{}, - SetupKeyId2AccountId: make(map[string]string), - PeerKeyId2AccountId: make(map[string]string), - UserId2AccountId: make(map[string]string), - storeFile: file, + Accounts: make(map[string]*Account), + mux: sync.Mutex{}, + SetupKeyId2AccountId: make(map[string]string), + PeerKeyId2AccountId: make(map[string]string), + UserId2AccountId: make(map[string]string), + PrivateDomain2AccountId: make(map[string]string), + storeFile: file, } err = s.persist(file) @@ -68,6 +70,7 @@ func restore(file string) (*FileStore, error) { store.SetupKeyId2AccountId = make(map[string]string) store.PeerKeyId2AccountId = make(map[string]string) store.UserId2AccountId = make(map[string]string) + store.PrivateDomain2AccountId = make(map[string]string) for accountId, account := range store.Accounts { for setupKeyId := range account.SetupKeys { store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId @@ -78,6 +81,12 @@ func restore(file string) (*FileStore, error) { for _, user := range account.Users { store.UserId2AccountId[user.Id] = accountId } + for _, user := range account.Users { + store.UserId2AccountId[user.Id] = accountId + } + if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { + store.PrivateDomain2AccountId[account.Domain] = accountId + } } return store, nil @@ -178,6 +187,10 @@ func (s *FileStore) SaveAccount(account *Account) error { s.UserId2AccountId[user.Id] = account.Id } + if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { + s.PrivateDomain2AccountId[account.Domain] = account.Id + } + err := s.persist(s.storeFile) if err != nil { return err @@ -186,6 +199,21 @@ func (s *FileStore) SaveAccount(account *Account) error { return nil } +func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { + + accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)] + if !accountIdFound { + return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private") + } + + account, err := s.GetAccount(accountId) + if err != nil { + return nil, err + } + + return account, nil +} + func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)] diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 0b83799c1..bf05f9b8b 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -1,6 +1,7 @@ package server import ( + "github.com/stretchr/testify/require" "github.com/wiretrustee/wiretrustee/util" "net" "path/filepath" @@ -131,34 +132,45 @@ func TestRestore(t *testing.T) { } account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - if account == nil { - t.Errorf("failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") + + require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") + + require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") + + require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") + + require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network") + + require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + + require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length") + + require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length") + + require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length") +} + +func TestGetAccountByPrivateDomain(t *testing.T) { + storeDir := t.TempDir() + + err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) + if err != nil { + t.Fatal(err) } - if account != nil && account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"] == nil { - t.Errorf("failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") + store, err := NewStore(storeDir) + if err != nil { + return } - if account != nil && account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"] == nil { - t.Errorf("failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") - } + existingDomain := "test.com" - if account != nil && account.Network == nil { - t.Errorf("failed to restore a FileStore file - missing Account Network") - } - - if account != nil && account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"] == nil { - t.Errorf("failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") - } - - if len(store.UserId2AccountId) != 2 { - t.Errorf("failed to restore a FileStore wrong UserId2AccountId mapping") - } - - if len(store.SetupKeyId2AccountId) != 1 { - t.Errorf("failed to restore a FileStore wrong SetupKeyId2AccountId mapping") - } + account, err := store.GetAccountByPrivateDomain(existingDomain) + require.NoError(t, err, "should found account") + require.Equal(t, existingDomain, account.Domain, "domains should match") + _, err = store.GetAccountByPrivateDomain("missing-domain.com") + require.Error(t, err, "should return error on domain lookup") } func newStore(t *testing.T) *FileStore { diff --git a/management/server/http/handler/peers.go b/management/server/http/handler/peers.go index 9212a4677..8da5d2151 100644 --- a/management/server/http/handler/peers.go +++ b/management/server/http/handler/peers.go @@ -72,7 +72,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) { jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) - account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) + account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims) if err != nil { return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) } diff --git a/management/server/http/handler/peers_test.go b/management/server/http/handler/peers_test.go index bb4aa4a03..ac04d4f53 100644 --- a/management/server/http/handler/peers_test.go +++ b/management/server/http/handler/peers_test.go @@ -17,9 +17,9 @@ import ( func initTestMetaData(peer ...*server.Peer) *Peers { return &Peers{ accountManager: &mock_server.MockAccountManager{ - GetAccountByUserOrAccountIdFunc: func(userId, accountId, domain string) (*server.Account, error) { + GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ - Id: accountId, + Id: claims.AccountId, Domain: "hotmail.com", Peers: map[string]*server.Peer{ "test_peer": peer[0], diff --git a/management/server/http/handler/setupkeys.go b/management/server/http/handler/setupkeys.go index ee3dcd628..06b411f4d 100644 --- a/management/server/http/handler/setupkeys.go +++ b/management/server/http/handler/setupkeys.go @@ -126,7 +126,7 @@ func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) extractor := jwtclaims.NewClaimsExtractor(nil) jwtClaims := extractor.ExtractClaimsFromRequestContext(r, h.authAudience) - account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) + account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims) if err != nil { return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) } diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 277f3c20d..2d7dc499a 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -2,7 +2,8 @@ package jwtclaims // AuthorizationClaims stores authorization information from JWTs type AuthorizationClaims struct { - UserId string - AccountId string - Domain string + UserId string + AccountId string + Domain string + DomainCategory string } diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index f6f609d12..aa37e18a9 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,10 +6,11 @@ import ( ) const ( - TokenUserProperty = "user" - AccountIDSuffix = "wt_account_id" - DomainIDSuffix = "wt_account_domain" - UserIDClaim = "sub" + TokenUserProperty = "user" + AccountIDSuffix = "wt_account_id" + DomainIDSuffix = "wt_account_domain" + DomainCategorySuffix = "wt_account_domain_category" + UserIDClaim = "sub" ) // Extract function type @@ -47,5 +48,9 @@ func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) Autho if ok { jwtClaims.Domain = domainClaim.(string) } + domainCategoryClaim, ok := claims[authAudiance+DomainCategorySuffix] + if ok { + jwtClaims.DomainCategory = domainCategoryClaim.(string) + } return jwtClaims } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index 7859d187a..9f4d7c7d3 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -19,6 +19,9 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st if claims.Domain != "" { claimMaps[audiance+DomainIDSuffix] = claims.Domain } + if claims.DomainCategory != "" { + claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") @@ -41,9 +44,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { name: "All Claim Fields", inputAudiance: "https://login/", inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - AccountId: "testAcc", + UserId: "test", + Domain: "test.com", + AccountId: "testAcc", + DomainCategory: "public", }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -72,6 +76,18 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { } testCase4 := test{ + name: "Category Is Empty", + inputAudiance: "https://login/", + inputAuthorizationClaims: AuthorizationClaims{ + UserId: "test", + Domain: "test.com", + AccountId: "testAcc", + }, + testingFunc: require.EqualValues, + expectedMSG: "extracted claims should match input claims", + } + + testCase5 := test{ name: "Only User ID Is set", inputAudiance: "https://login/", inputAuthorizationClaims: AuthorizationClaims{ @@ -81,7 +97,7 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { expectedMSG: "extracted claims should match input claims", } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { t.Run(testCase.name, func(t *testing.T) { request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d3907e72d..b0c30561a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,28 +2,30 @@ package mock_server import ( "github.com/wiretrustee/wiretrustee/management/server" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "github.com/wiretrustee/wiretrustee/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) - GetAccountByUserFunc func(userId string) (*server.Account, error) - AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) - RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error) - RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error) - GetAccountByIdFunc func(accountId string) (*server.Account, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - AccountExistsFunc func(accountId string) (*bool, error) - AddAccountFunc func(accountId, userId, domain string) (*server.Account, error) - GetPeerFunc func(peerKey string) (*server.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool) error - RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) - DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) - GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error) + GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) + GetAccountByUserFunc func(userId string) (*server.Account, error) + AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) + RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error) + RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error) + GetAccountByIdFunc func(accountId string) (*server.Account, error) + GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) + GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) + AccountExistsFunc func(accountId string) (*bool, error) + AddAccountFunc func(accountId, userId, domain string) (*server.Account, error) + GetPeerFunc func(peerKey string) (*server.Peer, error) + MarkPeerConnectedFunc func(peerKey string, connected bool) error + RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) + DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) + GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) + GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error) } func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) { @@ -75,6 +77,13 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, dom return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented") } +func (am *MockAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + if am.GetAccountWithAuthorizationClaimsFunc != nil { + return am.GetAccountWithAuthorizationClaimsFunc(claims) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountWithAuthorizationClaims not implemented") +} + func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { if am.AccountExistsFunc != nil { return am.AccountExistsFunc(accountId) diff --git a/management/server/store.go b/management/server/store.go index 5581de9d2..58fbc45b0 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,5 +9,6 @@ type Store interface { GetAccountPeers(accountId string) ([]*Peer, error) GetPeerAccount(peerKey string) (*Account, error) GetAccountBySetupKey(setupKey string) (*Account, error) + GetAccountByPrivateDomain(domain string) (*Account, error) SaveAccount(account *Account) error } diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json index d2c4743b0..1bd7311a1 100644 --- a/management/server/testdata/store.json +++ b/management/server/testdata/store.json @@ -2,6 +2,9 @@ "Accounts": { "bf1c8084-ba50-4ce7-9439-34653001fc3b": { "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", + "Domain": "test.com", + "DomainCategory": "private", + "IsDomainPrimaryAccount": true, "SetupKeys": { "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", diff --git a/management/server/user.go b/management/server/user.go index 9c6297805..c7eb72d8e 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -3,6 +3,7 @@ package server import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "strings" ) const ( @@ -34,6 +35,11 @@ func NewUser(id string, role UserRole) *User { } } +// NewRegularUser creates a new user with role UserRoleAdmin +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser) +} + // NewAdminUser creates a new user with role UserRoleAdmin func NewAdminUser(id string) *User { return NewUser(id, UserRoleAdmin) @@ -44,10 +50,12 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) am.mux.Lock() defer am.mux.Unlock() + lowerDomain := strings.ToLower(domain) + account, err := am.Store.GetUserAccount(userId) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - account = NewAccount(userId, domain) + account = NewAccount(userId, lowerDomain) account.Users[userId] = NewAdminUser(userId) err = am.Store.SaveAccount(account) if err != nil { @@ -59,8 +67,8 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) } } - if account.Domain != domain { - account.Domain = domain + if account.Domain != lowerDomain { + account.Domain = lowerDomain err = am.Store.SaveAccount(account) if err != nil { return nil, status.Errorf(codes.Internal, "failed updating account with domain")