diff --git a/management/client/rest/users.go b/management/client/rest/users.go index 372bcee45..31ffad051 100644 --- a/management/client/rest/users.go +++ b/management/client/rest/users.go @@ -80,3 +80,16 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error { return nil } + +// Current gets the current user info +// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user +func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) { + resp, err := a.c.newRequest(ctx, "GET", "/api/users/current", nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + ret, err := parseResponse[api.User](resp) + return &ret, err +} diff --git a/management/client/rest/users_test.go b/management/client/rest/users_test.go index 2ff8a0327..f68c5f083 100644 --- a/management/client/rest/users_test.go +++ b/management/client/rest/users_test.go @@ -196,8 +196,42 @@ func TestUsers_ResendInvitation_Err(t *testing.T) { }) } +func TestUsers_Current_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(testUser) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.Users.Current(context.Background()) + require.NoError(t, err) + assert.Equal(t, testUser, *ret) + }) +} + +func TestUsers_Current_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.Users.Current(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + func TestUsers_Integration(t *testing.T) { withBlackBoxServer(t, func(c *rest.Client) { + // rest client PAT is owner's + current, err := c.Users.Current(context.Background()) + require.NoError(t, err) + assert.Equal(t, "a23efe53-63fb-11ec-90d6-0242ac120003", current.Id) + assert.Equal(t, "owner", current.Role) + user, err := c.Users.Create(context.Background(), api.UserCreateRequest{ AutoGroups: []string{}, Email: ptr("test@example.com"), diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 807d05067..62ca6e97b 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -114,4 +114,5 @@ type Manager interface { CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) + GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 82971541d..c699e9eef 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -2397,6 +2397,29 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/current: + get: + summary: Retrieve current user + description: Get information about the current user + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A User object + content: + application/json: + schema: + $ref: '#/components/schemas/User' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: List all Peers diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 751311333..9bdb3e4ac 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -230,7 +230,7 @@ type Account struct { // AccountExtraSettings defines model for AccountExtraSettings. type AccountExtraSettings struct { - // NetworkTrafficLogsEnabled Enables or disables network traffic logs. If enabled, all network traffic logs from peers will be stored. + // NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"` // NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance) diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 19f56c464..c69c6b944 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -25,6 +25,7 @@ type handler struct { func AddEndpoints(accountManager account.Manager, router *mux.Router) { userHandler := newHandler(accountManager) router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") @@ -259,6 +260,29 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } +func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + ctx := r.Context() + userAuth, err := nbcontext.GetUserAuthFromContext(ctx) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID)) +} + func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index a6a904a4c..604954819 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" @@ -123,6 +124,64 @@ func initUsersTestData() *handler { return nil }, + GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { + switch userID { + case "not-found": + return nil, status.NewUserNotFoundError("not-found") + case "not-of-account": + return nil, status.NewUserNotPartOfAccountError() + case "blocked-user": + return nil, status.NewUserBlockedError() + case "service-user": + return nil, status.NewPermissionDeniedError() + case "owner": + return &types.UserInfo{ + ID: "owner", + Name: "", + Role: "owner", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", + Permissions: types.UserPermissions{ + DashboardView: "full", + }, + }, nil + case "regular-user": + return &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", + Permissions: types.UserPermissions{ + DashboardView: "limited", + }, + }, nil + + case "admin-user": + return &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + Permissions: types.UserPermissions{ + DashboardView: "full", + }, + }, nil + } + + return nil, fmt.Errorf("user id %s not handled", userID) + }, }, } } @@ -481,3 +540,73 @@ func TestDeleteUser(t *testing.T) { }) } } + +func TestCurrentUser(t *testing.T) { + tt := []struct { + name string + expectedStatus int + requestAuth nbcontext.UserAuth + }{ + { + name: "without auth", + expectedStatus: http.StatusInternalServerError, + }, + { + name: "user not found", + requestAuth: nbcontext.UserAuth{UserId: "not-found"}, + expectedStatus: http.StatusNotFound, + }, + { + name: "not of account", + requestAuth: nbcontext.UserAuth{UserId: "not-of-account"}, + expectedStatus: http.StatusForbidden, + }, + { + name: "blocked user", + requestAuth: nbcontext.UserAuth{UserId: "blocked-user"}, + expectedStatus: http.StatusForbidden, + }, + { + name: "service user", + requestAuth: nbcontext.UserAuth{UserId: "service-user"}, + expectedStatus: http.StatusForbidden, + }, + { + name: "owner", + requestAuth: nbcontext.UserAuth{UserId: "owner"}, + expectedStatus: http.StatusOK, + }, + { + name: "regular user", + requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, + expectedStatus: http.StatusOK, + }, + { + name: "admin user", + requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, + expectedStatus: http.StatusOK, + }, + } + + userHandler := initUsersTestData() + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/users/current", nil) + if tc.requestAuth.UserId != "" { + req = nbcontext.SetUserAuthInRequest(req, tc.requestAuth) + } + + rr := httptest.NewRecorder() + + userHandler.getCurrentUser(rr, req) + + res := rr.Result() + defer res.Body.Close() + + if status := rr.Code; status != tc.expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + } + }) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 008a7059f..8865c1e96 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -115,6 +115,7 @@ type MockAccountManager struct { CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -871,3 +872,10 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string } return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } + +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { + if am.GetCurrentUserInfoFunc != nil { + return am.GetCurrentUserInfoFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") +} diff --git a/management/server/user.go b/management/server/user.go index 3dee3f014..731958909 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -824,32 +824,33 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun if err != nil { return nil, status.NewPermissionValidationError(err) } + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } - accountUsers := []*types.User{user} - if allowed { + accountUsers := []*types.User{} + switch { + case allowed: accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } + case user.AccountID == accountID: + accountUsers = append(accountUsers, user) + default: + return map[string]*types.UserInfo{}, nil } return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers) } // BuildUserInfosForAccount builds user info for the given account. -func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) { +func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, _ string, accountUsers []*types.User) (map[string]*types.UserInfo, error) { var queriedUsers []*idp.UserData var err error - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) - if err != nil { - return nil, err - } - if !isNil(am.idpManager) { users := make(map[string]userLoggedInOnce, len(accountUsers)) usersFromIntegration := make([]*idp.UserData, 0) @@ -888,11 +889,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { for _, accountUser := range accountUsers { - if initiatorUser.IsRegularUser() && initiatorUser.Id != accountUser.Id { - // if user is not an admin then show only current user and do not show other users - continue - } - info, err := accountUser.ToUserInfo(nil, settings) if err != nil { return nil, err @@ -904,11 +900,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a } for _, localUser := range accountUsers { - if initiatorUser.IsRegularUser() && initiatorUser.Id != localUser.Id { - // if user is not an admin then show only current user and do not show other users - continue - } - var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { info, err = localUser.ToUserInfo(queriedUser, settings) @@ -1241,3 +1232,30 @@ func validateUserInvite(invite *types.UserInfo) error { return nil } + +// GetCurrentUserInfo retrieves the account's current user info +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.IsBlocked() { + return nil, status.NewUserBlockedError() + } + + if user.IsServiceUser { + return nil, status.NewPermissionDeniedError() + } + + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err + } + + userInfo, err := am.getUserInfo(ctx, user, accountID) + if err != nil { + return nil, err + } + + return userInfo, nil +} diff --git a/management/server/user_test.go b/management/server/user_test.go index c5da4ec88..098c8a31e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -13,6 +13,7 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -1607,3 +1608,175 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) { assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) assert.Equal(t, account1.Users[targetId].AutoGroups, user.AutoGroups) } + +func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "") + account1.Settings.RegularUsersViewBlocked = false + account1.Users["blocked-user"] = &types.User{ + Id: "blocked-user", + AccountID: account1.Id, + Blocked: true, + } + account1.Users["service-user"] = &types.User{ + Id: "service-user", + IsServiceUser: true, + ServiceUserName: "service-user", + } + account1.Users["regular-user"] = &types.User{ + Id: "regular-user", + Role: types.UserRoleUser, + } + account1.Users["admin-user"] = &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + } + require.NoError(t, store.SaveAccount(context.Background(), account1)) + + account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "") + account2.Users["settings-blocked-user"] = &types.User{ + Id: "settings-blocked-user", + Role: types.UserRoleUser, + } + require.NoError(t, store.SaveAccount(context.Background(), account2)) + + permissionsManager := permissions.NewManager(store) + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + } + + tt := []struct { + name string + accountId string + userId string + expectedErr error + expectedResult *types.UserInfo + }{ + { + name: "not found", + accountId: account1.Id, + userId: "not-found", + expectedErr: status.NewUserNotFoundError("not-found"), + }, + { + name: "not part of account", + accountId: account1.Id, + userId: "account2Owner", + expectedErr: status.NewUserNotPartOfAccountError(), + }, + { + name: "blocked", + accountId: account1.Id, + userId: "blocked-user", + expectedErr: status.NewUserBlockedError(), + }, + { + name: "service user", + accountId: account1.Id, + userId: "service-user", + expectedErr: status.NewPermissionDeniedError(), + }, + { + name: "owner user", + accountId: account1.Id, + userId: "account1Owner", + expectedResult: &types.UserInfo{ + ID: "account1Owner", + Name: "", + Role: "owner", + AutoGroups: []string{}, + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + Permissions: types.UserPermissions{ + DashboardView: "full", + }, + }, + }, + { + name: "regular user", + accountId: account1.Id, + userId: "regular-user", + expectedResult: &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + Permissions: types.UserPermissions{ + DashboardView: "limited", + }, + }, + }, + { + name: "admin user", + accountId: account1.Id, + userId: "admin-user", + expectedResult: &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + Permissions: types.UserPermissions{ + DashboardView: "full", + }, + }, + }, + { + name: "settings blocked regular user", + accountId: account2.Id, + userId: "settings-blocked-user", + expectedResult: &types.UserInfo{ + ID: "settings-blocked-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + Permissions: types.UserPermissions{ + DashboardView: "blocked", + }, + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId) + + if tc.expectedErr != nil { + assert.Equal(t, err, tc.expectedErr) + return + } + + require.NoError(t, err) + assert.EqualValues(t, tc.expectedResult, result) + }) + } +}