diff --git a/management/server/account.go b/management/server/account.go index 1852ae763..4221e5f59 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -44,7 +44,6 @@ type AccountManager interface { GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) - AddAccount(accountId, userId, domain string) (*Account, error) GetPeer(peerKey string) (*Peer, error) MarkPeerConnected(peerKey string, connected bool) error RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) @@ -171,7 +170,7 @@ func BuildManager( for _, account := range store.GetAllAccounts() { _, err := account.GetGroupAll() if err != nil { - am.addAllGroup(account) + addAllGroup(account) if err := store.SaveAccount(account); err != nil { return nil, err } @@ -524,8 +523,6 @@ func (am *DefaultAccountManager) handleNewUserAccount( } } else { account = NewAccount(claims.UserId, lowerDomain) - am.addAllGroup(account) - account.Users[claims.UserId] = NewAdminUser(claims.UserId) err = am.updateAccountDomainAttributes(account, claims, true) if err != nil { return nil, err @@ -622,29 +619,8 @@ func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) return &res, nil } -// AddAccount generates a new Account with a provided accountId and userId, saves to the Store -func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*Account, error) { - am.mux.Lock() - defer am.mux.Unlock() - - return am.createAccount(accountId, userId, domain) -} - -func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) { - account := newAccountWithId(accountId, userId, domain) - - am.addAllGroup(account) - - err := am.Store.SaveAccount(account) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed creating account") - } - - return account, nil -} - // addAllGroup to account object if it doesn't exists -func (am *DefaultAccountManager) addAllGroup(account *Account) { +func addAllGroup(account *Account) { if len(account.Groups) == 0 { allGroup := &Group{ ID: xid.New().String(), @@ -677,10 +653,10 @@ func newAccountWithId(accountId, userId, domain string) *Account { network := NewNetwork() peers := make(map[string]*Peer) users := make(map[string]*User) - + users[userId] = NewAdminUser(userId) log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) - return &Account{ + acc := &Account{ Id: accountId, SetupKeys: setupKeys, Network: network, @@ -689,6 +665,9 @@ func newAccountWithId(accountId, userId, domain string) *Account { CreatedBy: userId, Domain: domain, } + + addAllGroup(acc) + return acc } func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey { diff --git a/management/server/account_test.go b/management/server/account_test.go index 3dab9b347..c7de2e838 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -11,6 +11,95 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { + peer := &Peer{ + Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", + Name: "test-host@netbird.io", + Meta: PeerSystemMeta{ + Hostname: "test-host@netbird.io", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + + var setupKey string + for _, key := range account.SetupKeys { + setupKey = key.Key + } + + _, err := manager.AddPeer(setupKey, userID, peer) + if err != nil { + t.Error("expected to add new peer successfully after creating new account, but failed", err) + } +} + +func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) { + if len(account.Peers) != 0 { + t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) + } + + if len(account.SetupKeys) != 2 { + t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys)) + } + + ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}} + if !ipNet.Contains(account.Network.Net.IP) { + t.Errorf("expected account's Network to be a subnet of %v, got %v", ipNet.String(), account.Network.Net.String()) + } + + g, err := account.GetGroupAll() + if err != nil { + t.Fatal(err) + } + if g.Name != "All" { + t.Errorf("expecting account to have group ALL added by default") + } + if len(account.Users) != len(expectedUsers) { + t.Errorf("expecting account to have %d users, got %d", len(expectedUsers), len(account.Users)) + } + + if account.Users[createdBy] == nil { + t.Errorf("expecting account to have createdBy user %s in a user map ", createdBy) + } + + for _, expectedUserID := range expectedUsers { + if account.Users[expectedUserID] == nil { + t.Errorf("expecting account to have a user %s in a user map", expectedUserID) + } + } + + if account.CreatedBy != createdBy { + t.Errorf("expecting newly created account to be created by user %s, got %s", createdBy, account.CreatedBy) + } + + if account.Domain != domain { + t.Errorf("expecting newly created account to have domain %s, got %s", domain, account.Domain) + } + + if len(account.Rules) != 1 { + t.Errorf("expecting newly created account to have 1 rule, got %d", len(account.Rules)) + } + + for _, rule := range account.Rules { + if rule.Name != "Default" { + t.Errorf("expecting newly created account to have Default rule, got %s", rule.Name) + } + } +} + +func TestNewAccount(t *testing.T) { + + domain := "netbird.io" + userId := "account_creator" + account := NewAccount(userId, domain) + verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) +} + func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -51,6 +140,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole UserRole expectedDomainCategory string expectedPrimaryDomainStatus bool + expectedCreatedBy string + expectedUsers []string } var ( @@ -77,6 +168,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleAdmin, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, } initUnknown := defaultInitAccount @@ -96,6 +189,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleAdmin, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, } testCase3 := test{ @@ -111,6 +206,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleAdmin, expectedDomainCategory: PrivateCategory, expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, } privateInitAccount := defaultInitAccount @@ -121,7 +218,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { name: "New Regular User With Existing Private Domain", inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, - UserId: "pvt-domain-user", + UserId: "new-pvt-domain-user", DomainCategory: PrivateCategory, }, inputUpdateAttrs: true, @@ -131,6 +228,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleUser, expectedDomainCategory: PrivateCategory, expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, } testCase5 := test{ @@ -146,6 +245,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleAdmin, expectedDomainCategory: PrivateCategory, expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, } testCase6 := test{ @@ -162,6 +263,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { expectedUserRole: UserRoleAdmin, expectedDomainCategory: PrivateCategory, expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, } for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} { t.Run(testCase.name, func(t *testing.T) { @@ -182,6 +285,8 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims) require.NoError(t, err, "support function failed") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) + verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) testCase.testingFunc(t, initAccount.Id, account.Id, testCase.expectedMSG) @@ -255,41 +360,6 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_AddAccount(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return - } - - expectedId := "test_account" - userId := "account_creator" - expectedPeersSize := 0 - expectedSetupKeysSize := 2 - - account, err := manager.AddAccount(expectedId, userId, "") - if err != nil { - t.Fatal(err) - } - - if account.Id != expectedId { - t.Errorf("expected account to have Id = %s, got %s", expectedId, account.Id) - } - - if len(account.Peers) != expectedPeersSize { - t.Errorf("expected account to have len(Peers) = %v, got %v", expectedPeersSize, len(account.Peers)) - } - - if len(account.SetupKeys) != expectedSetupKeysSize { - t.Errorf("expected account to have len(SetupKeys) = %v, got %v", expectedSetupKeysSize, len(account.SetupKeys)) - } - - ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}} - if !ipNet.Contains(account.Network.Net.IP) { - t.Errorf("expected account's Network to be a subnet of %v, got %v", ipNet.String(), account.Network.Net.String()) - } -} - func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -320,6 +390,15 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { } } +func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { + account := newAccountWithId(accountID, userID, domain) + err := am.Store.SaveAccount(account) + if err != nil { + return nil, err + } + return account, nil +} + func TestAccountManager_AccountExists(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -329,7 +408,7 @@ func TestAccountManager_AccountExists(t *testing.T) { expectedId := "test_account" userId := "account_creator" - _, err = manager.AddAccount(expectedId, userId, "") + _, err = createAccount(manager, expectedId, userId, "") if err != nil { t.Fatal(err) } @@ -353,7 +432,7 @@ func TestAccountManager_GetAccount(t *testing.T) { expectedId := "test_account" userId := "account_creator" - account, err := manager.AddAccount(expectedId, userId, "") + account, err := createAccount(manager, expectedId, userId, "") if err != nil { t.Fatal(err) } @@ -389,7 +468,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator", "") + account, err := createAccount(manager, "test_account", "account_creator", "") if err != nil { t.Fatal(err) } @@ -521,7 +600,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator", "") + account, err := createAccount(manager, "test_account", "account_creator", "") if err != nil { t.Fatal(err) } @@ -704,7 +783,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator", "") + account, err := createAccount(manager, "test_account", "account_creator", "") if err != nil { t.Fatal(err) } @@ -757,7 +836,7 @@ func TestGetUsersFromAccount(t *testing.T) { users := map[string]*User{"1": {Id: "1", Role: "admin"}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} accountId := "test_account_id" - account, err := manager.AddAccount(accountId, users["1"].Id, "") + account, err := createAccount(manager, accountId, users["1"].Id, "") if err != nil { t.Fatal(err) } @@ -788,7 +867,7 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) { return } - account, err := manager.AddAccount("test_account", "account_creator", "") + account, err := createAccount(manager, "test_account", "account_creator", "") if err != nil { t.Fatal(err) } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index df6c7d40e..27b996949 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -34,7 +34,6 @@ func TestSaveAccount(t *testing.T) { store := newStore(t) account := NewAccount("testuser", "") - account.Users["testuser"] = NewAdminUser("testuser") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &Peer{ @@ -74,7 +73,6 @@ func TestStore(t *testing.T) { store := newStore(t) account := NewAccount("testuser", "") - account.Users["testuser"] = NewAdminUser("testuser") account.Peers["testpeer"] = &Peer{ Key: "peerkey", SetupKey: "peerkeysetupkey", diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3a5aa5ec4..341f7a3f2 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -19,7 +19,6 @@ type MockAccountManager struct { GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExistsFunc func(accountId string) (*bool, error) - AddAccountFunc func(accountId, userId, domain string) (*server.Account, error) GetPeerFunc func(peerKey string) (*server.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool) error RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) @@ -139,15 +138,6 @@ func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented") } -func (am *MockAccountManager) AddAccount( - accountId, userId, domain string, -) (*server.Account, error) { - if am.AddAccountFunc != nil { - return am.AddAccountFunc(accountId, userId, domain) - } - return nil, status.Errorf(codes.Unimplemented, "method AddAccount not implemented") -} - func (am *MockAccountManager) GetPeer(peerKey string) (*server.Peer, error) { if am.GetPeerFunc != nil { return am.GetPeerFunc(peerKey) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 58210e61d..c48ac9975 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -16,7 +16,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { expectedId := "test_account" userId := "account_creator" - account, err := manager.AddAccount(expectedId, userId, "") + account, err := createAccount(manager, expectedId, userId, "") if err != nil { t.Fatal(err) } @@ -89,7 +89,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) { expectedId := "test_account" userId := "account_creator" - account, err := manager.AddAccount(expectedId, userId, "") + account, err := createAccount(manager, expectedId, userId, "") if err != nil { t.Fatal(err) } diff --git a/management/server/user.go b/management/server/user.go index 0691c6db7..05c42af2e 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -60,8 +60,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { account = NewAccount(userId, lowerDomain) - account.Users[userId] = NewAdminUser(userId) - am.addAllGroup(account) err = am.Store.SaveAccount(account) if err != nil { return nil, status.Errorf(codes.Internal, "failed creating account")