Block user through HTTP API (#846)

The new functionality allows blocking a user in the Management service.
Blocked users lose access to the Dashboard, aren't able to modify the network map,
and all of their connected devices disconnect and are set to the "login expired" state.

Technically all above was achieved with the updated PUT /api/users endpoint,
that was extended with the is_blocked field.
This commit is contained in:
Misha Bragin 2023-05-11 18:09:36 +02:00 committed by GitHub
parent 9f758b2015
commit e3d2b6a408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 505 additions and 155 deletions

View File

@ -49,16 +49,16 @@ type AccountManager interface {
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, executingUserID string, targetUserID string) error
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
@ -69,10 +69,10 @@ type AccountManager interface {
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
@ -179,6 +179,7 @@ type UserInfo struct {
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes

View File

@ -91,6 +91,10 @@ const (
ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
// UserBlocked indicates that a user blocked another user
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
)
const (
@ -184,6 +188,10 @@ const (
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked"
)
// Activity that triggered an Event
@ -282,6 +290,10 @@ func (a Activity) Message() string {
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
case UserBlocked:
return UserBlockedMessage
case UserUnblocked:
return UserUnblockedMessage
default:
return "UNKNOWN_ACTIVITY"
}
@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
return "user.join"
case UserInvited:
return "user.invite"
case UserBlocked:
return "user.block"
case UserUnblocked:
return "user.unblock"
case AccountCreated:
return "account.create"
case RuleAdded:

View File

@ -65,7 +65,7 @@ components:
status:
description: User's status
type: string
enum: [ "active","invited","disabled" ]
enum: [ "active","invited","blocked" ]
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
@ -79,6 +79,9 @@ components:
description: Is true if this user is a service user
type: boolean
readOnly: true
is_blocked:
description: Is true if this user is blocked. Blocked users can't use the system
type: boolean
required:
- id
- email
@ -86,6 +89,7 @@ components:
- role
- auto_groups
- status
- is_blocked
UserRequest:
type: object
properties:
@ -97,9 +101,13 @@ components:
type: array
items:
type: string
is_blocked:
description: If set to true then user is blocked and can't use the system
type: boolean
required:
- role
- auto_groups
- is_blocked
UserCreateRequest:
type: object
properties:
@ -645,7 +653,7 @@ components:
description: The string code of the activity that occurred during the event
type: string
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
"user.role.update",
"user.role.update", "user.block", "user.unblock",
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
"setupkey.group.delete", "setupkey.group.add",
"rule.add", "rule.delete", "rule.update",

View File

@ -46,6 +46,7 @@ const (
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
EventActivityCodeUserBlock EventActivityCode = "user.block"
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
EventActivityCodeUserInvite EventActivityCode = "user.invite"
@ -53,6 +54,7 @@ const (
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
)
// Defines values for NameserverNsType.
@ -68,9 +70,9 @@ const (
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserStatusActive UserStatus = "active"
UserStatusBlocked UserStatus = "blocked"
UserStatusInvited UserStatus = "invited"
)
// Account defines model for Account.
@ -552,6 +554,9 @@ type User struct {
// Id User ID
Id string `json:"id"`
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
IsBlocked bool `json:"is_blocked"`
// IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"`
@ -594,6 +599,9 @@ type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// IsBlocked If set to true then user is blocked and can't use the system
IsBlocked bool `json:"is_blocked"`
// Role User's NetBird account role
Role string `json:"role"`
}

View File

@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.IsUserAdmin)
accountManager.GetUser)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()

View File

@ -6,28 +6,30 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
isUserAdmin IsUserAdminFunc
claimsExtract jwtclaims.ClaimsExtractor
getUser GetUser
}
// NewAccessControl instance constructor
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
return &AccessControl{
isUserAdmin: isUserAdmin,
claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
getUser: getUser,
}
}
@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(claims)
user, err := a.getUser(claims)
if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if !ok {
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
if !user.IsAdmin() {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil {
log.Debugf("Regex failed")
log.Debugf("regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w)
return
}
if ok {
log.Debugf("Valid Path")
log.Debugf("valid Path")
h.ServeHTTP(w, r)
return
}

View File

@ -63,7 +63,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler {
return &PATHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
},
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
}
return nil
},
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}

View File

@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
if req.AutoGroups == nil {
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
})
if err != nil {
util.WriteError(err, w)
return
@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
case "invited":
userStatus = api.UserStatusInvited
default:
userStatus = api.UserStatusDisabled
userStatus = api.UserStatusBlocked
}
if user.IsBlocked {
userStatus = api.UserStatusBlocked
}
isCurrent := user.ID == currenUserID
@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
Status: userStatus,
IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
}
}

View File

@ -3,6 +3,7 @@ package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
AutoGroups: []string{"group_1"},
},
},
}
@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
}
return key, nil
},
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
}
return nil
},
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil)
if err != nil {
return nil, err
}
return info, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
}
}
func TestUpdateUser(t *testing.T) {
tt := []struct {
name string
expectedStatusCode int
requestType string
requestPath string
requestBody io.Reader
expectedUserID string
expectedRole string
expectedStatus string
expectedBlocked bool
expectedIsServiceUser bool
expectedGroups []string
}{
{
name: "Update_Block_User",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: true,
expectedRole: "user",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
},
{
name: "Update_Change_Role_To_Admin",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Update_Groups",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Should_Fail_Because_AutoGroups_Is_Absent",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusBadRequest,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatusCode {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if tc.expectedStatusCode == 200 {
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
respBody := &api.User{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("response content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedUserID, respBody.Id)
assert.Equal(t, tc.expectedRole, respBody.Role)
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
for _, expectedGroup := range tc.expectedGroups {
exists := false
for _, actualGroup := range respBody.AutoGroups {
if expectedGroup == actualGroup {
exists = true
}
}
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
}
}
})
}
}
func TestCreateUser(t *testing.T) {
name := "name"
email := "email"

View File

@ -19,7 +19,7 @@ type MockAccountManager struct {
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
@ -60,11 +60,11 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
}
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if am.DeletePATFunc != nil {
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
}
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
}
// GetPAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if am.GetPATFunc != nil {
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
}
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if am.GetAllPATsFunc != nil {
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
}
@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
}
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
if am.IsUserAdminFunc != nil {
return am.IsUserAdminFunc(claims)
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(claims)
}
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
}
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}

View File

@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
if peerLoginExpired(peer, account) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}
updateRemotePeers := false
if peerLoginExpired(peer, account) {
err = checkAuth(login.UserID, peer)
@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
}
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
}
return nil
}
func checkAuth(loginUserID string, peer *Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.

View File

@ -51,15 +51,22 @@ type User struct {
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string
PATs map[string]*PersonalAccessToken
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
}
// IsAdmin returns true if user is an admin, false otherwise
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
// IsAdmin returns true if the user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil
}
if userData.ID != u.Id {
@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
}, nil
}
@ -113,6 +122,7 @@ func (u *User) Copy() *User {
IsServiceUser: u.IsServiceUser,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
}
}
@ -138,7 +148,7 @@ func NewAdminUser(id string) *User {
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
}
meta := map[string]any{"name": newUser.ServiceUserName}
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{
ID: newUser.Id,
@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
}
if user != nil {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
users, err := am.idpManager.GetUserByEmail(invite.Email)
@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
}
if len(users) > 0 {
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
return newUser.toUserInfo(idpUser)
return newUser.ToUserInfo(idpUser)
}
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
return user, nil
}
// DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
}
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUserID)
@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
}
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.NotFound, "targetUser not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
return pat, nil
}
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(targetUser.PATs, tokenID)
@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
}
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
return nil, status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
}
@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
}
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return nil, status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
executingUser := account.Users[initiatorUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
return pats, nil
}
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
return nil, err
}
initiatorUser, err := account.FindUser(initiatorUserID)
if err != nil {
return nil, err
}
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "update not found")
}
// only auto groups, revoked status, and name can be updated for now
newUser := oldUser.Copy()
newUser.AutoGroups = update.AutoGroups
newUser.Role = update.Role
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
var peerIDs []string
for _, peer := range blockedPeers {
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return nil, err
}
}
am.peersUpdateManager.CloseChannels(peerIDs)
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
return nil, err
}
}
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
defer func() {
// store activity logs
if oldUser.Role != newUser.Role {
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
}
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
if update.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
}
}
}
}()
if !isNil(am.idpManager) && !newUser.IsServiceUser {
@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
if userData == nil {
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
}
return newUser.toUserInfo(userData)
return newUser.ToUserInfo(userData)
}
return newUser.toUserInfo(nil)
return newUser.ToUserInfo(nil)
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
return account, nil
}
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return false, fmt.Errorf("get account: %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return false, status.Errorf(status.NotFound, "user not found")
}
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.toUserInfo(nil)
info, err := accountUser.ToUserInfo(nil)
if err != nil {
return nil, err
}
@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.toUserInfo(queriedUser)
info, err = localUser.ToUserInfo(queriedUser)
if err != nil {
return nil, err
}

View File

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
LastUsed: time.Now(),
},
},
Blocked: false,
}
err := validateStruct(user)
@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
field := structVal.Field(i)
fieldName := structType.Field(i).Name
isSet := field.IsValid() && !field.IsZero()
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
if !isSet {
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
}
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
UserId: mockUserID,
}
ok, err := am.IsUserAdmin(claims)
user, err := am.GetUser(claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.True(t, ok)
assert.Equal(t, mockUserID, user.Id)
assert.True(t, user.IsAdmin())
assert.False(t, user.IsBlocked())
}
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
Role: "user",
}
func TestUser_IsAdmin(t *testing.T) {
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
user := NewAdminUser(mockUserID)
assert.True(t, user.IsAdmin())
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
claims := jwtclaims.AuthorizationClaims{
UserId: mockUserID,
}
ok, err := am.IsUserAdmin(claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.False(t, ok)
user = NewRegularUser(mockUserID)
assert.False(t, user.IsAdmin())
}
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
assert.Equal(t, 1, len(users))
assert.Equal(t, mockServiceUserID, users[0].ID)
}
func TestDefaultAccountManager_SaveUser(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
regularUserID := "regularUser"
tt := []struct {
name string
adminInitiator bool
update *User
expectedErr bool
}{
{
name: "Should_Fail_To_Update_Admin_Role",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleUser,
Blocked: false,
},
}, {
name: "Should_Fail_When_Admin_Blocks_Themselves",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_Non_Existing_User",
expectedErr: true,
adminInitiator: true,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
expectedErr: true,
adminInitiator: false,
update: &User{
Id: userID,
Role: UserRoleAdmin,
Blocked: true,
},
},
{
name: "Should_Update_User",
expectedErr: false,
adminInitiator: true,
update: &User{
Id: regularUserID,
Role: UserRoleAdmin,
Blocked: true,
},
},
}
for _, tc := range tt {
// create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
if err != nil {
t.Fatal(err)
}
// create a regular user
account.Users[regularUserID] = NewRegularUser(regularUserID)
err = manager.Store.SaveAccount(account)
if err != nil {
t.Fatal(err)
}
initiatorID := userID
if !tc.adminInitiator {
initiatorID = regularUserID
}
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error")
} else {
require.NoError(t, err, "expecting SaveUser not to throw an error")
assert.NotNil(t, updated)
assert.Equal(t, string(tc.update.Role), updated.Role)
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
}
}
}