mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-11 16:38:27 +01:00
Block user through HTTP API (#846)
The new functionality allows blocking a user in the Management service. Blocked users lose access to the Dashboard, aren't able to modify the network map, and all of their connected devices disconnect and are set to the "login expired" state. Technically all above was achieved with the updated PUT /api/users endpoint, that was extended with the is_blocked field.
This commit is contained in:
parent
9f758b2015
commit
e3d2b6a408
@ -49,16 +49,16 @@ type AccountManager interface {
|
||||
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, executingUserID string, targetUserID string) error
|
||||
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
||||
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||
@ -69,10 +69,10 @@ type AccountManager interface {
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
GetPeerNetwork(peerID string) (*Network, error)
|
||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
@ -179,6 +179,7 @@ type UserInfo struct {
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
}
|
||||
|
||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
|
@ -91,6 +91,10 @@ const (
|
||||
ServiceUserCreated
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
// UserBlocked indicates that a user blocked another user
|
||||
UserBlocked
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
)
|
||||
|
||||
const (
|
||||
@ -184,6 +188,10 @@ const (
|
||||
ServiceUserCreatedMessage string = "Service user created"
|
||||
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
|
||||
ServiceUserDeletedMessage string = "Service user deleted"
|
||||
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
|
||||
UserBlockedMessage string = "User blocked"
|
||||
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||
UserUnblockedMessage string = "User unblocked"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@ -282,6 +290,10 @@ func (a Activity) Message() string {
|
||||
return ServiceUserCreatedMessage
|
||||
case ServiceUserDeleted:
|
||||
return ServiceUserDeletedMessage
|
||||
case UserBlocked:
|
||||
return UserBlockedMessage
|
||||
case UserUnblocked:
|
||||
return UserUnblockedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@ -300,6 +312,10 @@ func (a Activity) StringCode() string {
|
||||
return "user.join"
|
||||
case UserInvited:
|
||||
return "user.invite"
|
||||
case UserBlocked:
|
||||
return "user.block"
|
||||
case UserUnblocked:
|
||||
return "user.unblock"
|
||||
case AccountCreated:
|
||||
return "account.create"
|
||||
case RuleAdded:
|
||||
|
@ -65,7 +65,7 @@ components:
|
||||
status:
|
||||
description: User's status
|
||||
type: string
|
||||
enum: [ "active","invited","disabled" ]
|
||||
enum: [ "active","invited","blocked" ]
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
@ -79,6 +79,9 @@ components:
|
||||
description: Is true if this user is a service user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
is_blocked:
|
||||
description: Is true if this user is blocked. Blocked users can't use the system
|
||||
type: boolean
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
@ -86,6 +89,7 @@ components:
|
||||
- role
|
||||
- auto_groups
|
||||
- status
|
||||
- is_blocked
|
||||
UserRequest:
|
||||
type: object
|
||||
properties:
|
||||
@ -97,9 +101,13 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
is_blocked:
|
||||
description: If set to true then user is blocked and can't use the system
|
||||
type: boolean
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
- is_blocked
|
||||
UserCreateRequest:
|
||||
type: object
|
||||
properties:
|
||||
@ -645,7 +653,7 @@ components:
|
||||
description: The string code of the activity that occurred during the event
|
||||
type: string
|
||||
enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete",
|
||||
"user.role.update",
|
||||
"user.role.update", "user.block", "user.unblock",
|
||||
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
|
||||
"setupkey.group.delete", "setupkey.group.add",
|
||||
"rule.add", "rule.delete", "rule.update",
|
||||
|
@ -46,6 +46,7 @@ const (
|
||||
EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add"
|
||||
EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke"
|
||||
EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update"
|
||||
EventActivityCodeUserBlock EventActivityCode = "user.block"
|
||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||
@ -53,6 +54,7 @@ const (
|
||||
EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add"
|
||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||
EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update"
|
||||
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
|
||||
)
|
||||
|
||||
// Defines values for NameserverNsType.
|
||||
@ -68,9 +70,9 @@ const (
|
||||
|
||||
// Defines values for UserStatus.
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusBlocked UserStatus = "blocked"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
)
|
||||
|
||||
// Account defines model for Account.
|
||||
@ -552,6 +554,9 @@ type User struct {
|
||||
// Id User ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// IsBlocked Is true if this user is blocked. Blocked users can't use the system
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
|
||||
// IsCurrent Is true if authenticated user is the same as this user
|
||||
IsCurrent *bool `json:"is_current,omitempty"`
|
||||
|
||||
@ -594,6 +599,9 @@ type UserRequest struct {
|
||||
// AutoGroups Groups to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// IsBlocked If set to true then user is blocked and can't use the system
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
accountManager.IsUserAdmin)
|
||||
accountManager.GetUser)
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
@ -6,28 +6,30 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
isUserAdmin IsUserAdminFunc
|
||||
claimsExtract jwtclaims.ClaimsExtractor
|
||||
getUser GetUser
|
||||
}
|
||||
|
||||
// NewAccessControl instance constructor
|
||||
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
||||
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||
return &AccessControl{
|
||||
isUserAdmin: isUserAdmin,
|
||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
),
|
||||
getUser: getUser,
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
ok, err := a.isUserAdmin(claims)
|
||||
user, err := a.getUser(claims)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
|
||||
if user.IsBlocked() {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Debugf("Regex failed")
|
||||
log.Debugf("regex failed")
|
||||
util.WriteError(status.Errorf(status.Internal, ""), w)
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
log.Debugf("Valid Path")
|
||||
log.Debugf("valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ var testAccount = &server.Account{
|
||||
func initPATTestData() *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -79,7 +79,7 @@ func initPATTestData() *PATHandler {
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
},
|
||||
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -91,7 +91,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -103,7 +103,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
|
@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
Id: userID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Blocked: req.IsBlocked,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
case "invited":
|
||||
userStatus = api.UserStatusInvited
|
||||
default:
|
||||
userStatus = api.UserStatusDisabled
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
if user.IsBlocked {
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
|
||||
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
|
||||
info, err := update.Copy().ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatusCode int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedUserID string
|
||||
expectedRole string
|
||||
expectedStatus string
|
||||
expectedBlocked bool
|
||||
expectedIsServiceUser bool
|
||||
expectedGroups []string
|
||||
}{
|
||||
{
|
||||
name: "Update_Block_User",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: true,
|
||||
expectedRole: "user",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Change_Role_To_Admin",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Groups",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_Because_AutoGroups_Is_Absent",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatusCode {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
if tc.expectedStatusCode == 200 {
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
respBody := &api.User{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("response content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedUserID, respBody.Id)
|
||||
assert.Equal(t, tc.expectedRole, respBody.Role)
|
||||
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
|
||||
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
|
||||
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
|
||||
|
||||
for _, expectedGroup := range tc.expectedGroups {
|
||||
exists := false
|
||||
for _, actualGroup := range respBody.AutoGroups {
|
||||
if expectedGroup == actualGroup {
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
name := "name"
|
||||
email := "email"
|
||||
|
@ -19,7 +19,7 @@ type MockAccountManager struct {
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||
@ -60,11 +60,11 @@ type MockAccountManager struct {
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
|
||||
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||
}
|
||||
|
||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if am.DeletePATFunc != nil {
|
||||
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||
}
|
||||
|
||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if am.GetPATFunc != nil {
|
||||
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||
}
|
||||
|
||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if am.GetAllPATsFunc != nil {
|
||||
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
|
||||
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||
}
|
||||
@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
|
||||
}
|
||||
|
||||
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
|
||||
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
if am.IsUserAdminFunc != nil {
|
||||
return am.IsUserAdminFunc(claims)
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
return am.GetUserFunc(claims)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
|
||||
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(peer, account) {
|
||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(peer, account) {
|
||||
err = checkAuth(login.UserID, peer)
|
||||
@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||
}
|
||||
if user.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAuth(loginUserID string, peer *Peer) error {
|
||||
if loginUserID == "" {
|
||||
// absence of a user ID indicates that JWT wasn't provided.
|
||||
|
@ -51,15 +51,22 @@ type User struct {
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
PATs map[string]*PersonalAccessToken
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
}
|
||||
|
||||
// IsAdmin returns true if user is an admin, false otherwise
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
func (u *User) IsBlocked() bool {
|
||||
return u.Blocked
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the user is an admin, false otherwise
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == UserRoleAdmin
|
||||
}
|
||||
|
||||
// toUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -113,6 +122,7 @@ func (u *User) Copy() *User {
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,7 +148,7 @@ func NewAdminUser(id string) *User {
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
|
||||
return &UserInfo{
|
||||
ID: newUser.Id,
|
||||
@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||
@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||
@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite
|
||||
|
||||
am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||
|
||||
return newUser.toUserInfo(idpUser)
|
||||
return newUser.ToUserInfo(idpUser)
|
||||
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the given account.
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
|
||||
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
|
||||
@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(targetUser.PATs, tokenID)
|
||||
|
||||
@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
|
||||
}
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
||||
}
|
||||
|
||||
@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
executingUser := account.Users[executingUserID]
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups field is allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
|
||||
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initiatorUser, err := account.FindUser(initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations")
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
||||
}
|
||||
|
||||
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
|
||||
if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.Role = update.Role
|
||||
newUser.Blocked = update.Blocked
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "update not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.AutoGroups = update.AutoGroups
|
||||
newUser.Role = update.Role
|
||||
|
||||
account.Users[newUser.Id] = newUser
|
||||
|
||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||
// expire peers that belong to the user who's getting blocked
|
||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var peerIDs []string
|
||||
for _, peer := range blockedPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||
if err != nil {
|
||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
||||
return nil, err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// store activity logs
|
||||
if oldUser.Role != newUser.Role {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
}
|
||||
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
if update.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||
@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
|
||||
if userData == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||
}
|
||||
return newUser.toUserInfo(userData)
|
||||
return newUser.ToUserInfo(userData)
|
||||
}
|
||||
return newUser.toUserInfo(nil)
|
||||
return newUser.ToUserInfo(nil)
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return false, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
return user.Role == UserRoleAdmin, nil
|
||||
}
|
||||
|
||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||
// based on provided user role.
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
|
||||
@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
}
|
||||
info, err := accountUser.toUserInfo(nil)
|
||||
info, err := accountUser.ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
|
||||
var info *UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.toUserInfo(queriedUser)
|
||||
info, err = localUser.ToUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) {
|
||||
LastUsed: time.Now(),
|
||||
},
|
||||
},
|
||||
Blocked: false,
|
||||
}
|
||||
|
||||
err := validateStruct(user)
|
||||
@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) {
|
||||
field := structVal.Field(i)
|
||||
fieldName := structType.Field(i).Name
|
||||
|
||||
isSet := field.IsValid() && !field.IsZero()
|
||||
isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool")
|
||||
|
||||
if !isSet {
|
||||
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
|
||||
@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
|
||||
UserId: mockUserID,
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(claims)
|
||||
user, err := am.GetUser(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, mockUserID, user.Id)
|
||||
assert.True(t, user.IsAdmin())
|
||||
assert.False(t, user.IsBlocked())
|
||||
}
|
||||
|
||||
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
Role: "user",
|
||||
}
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
user := NewAdminUser(mockUserID)
|
||||
assert.True(t, user.IsAdmin())
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: mockUserID,
|
||||
}
|
||||
|
||||
ok, err := am.IsUserAdmin(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.False(t, ok)
|
||||
user = NewRegularUser(mockUserID)
|
||||
assert.False(t, user.IsAdmin())
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
assert.Equal(t, 1, len(users))
|
||||
assert.Equal(t, mockServiceUserID, users[0].ID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
regularUserID := "regularUser"
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
adminInitiator bool
|
||||
update *User
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Should_Fail_To_Update_Admin_Role",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleUser,
|
||||
Blocked: false,
|
||||
},
|
||||
}, {
|
||||
name: "Should_Fail_When_Admin_Blocks_Themselves",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_To_Update_Non_Existing_User",
|
||||
expectedErr: true,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
|
||||
expectedErr: true,
|
||||
adminInitiator: false,
|
||||
update: &User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should_Update_User",
|
||||
expectedErr: false,
|
||||
adminInitiator: true,
|
||||
update: &User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
// create an account and an admin user
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a regular user
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
err = manager.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
initiatorID := userID
|
||||
if !tc.adminInitiator {
|
||||
initiatorID = regularUserID
|
||||
}
|
||||
|
||||
updated, err := manager.SaveUser(account.Id, initiatorID, tc.update)
|
||||
if tc.expectedErr {
|
||||
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||
} else {
|
||||
require.NoError(t, err, "expecting SaveUser not to throw an error")
|
||||
assert.NotNil(t, updated)
|
||||
|
||||
assert.Equal(t, string(tc.update.Role), updated.Role)
|
||||
assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user