mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-10 16:08:14 +01:00
02a3feddb8
* Add mysql store support * Add support to disable activity events recording
257 lines
8.6 KiB
Go
257 lines
8.6 KiB
Go
package users
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/gorilla/mux"
|
|
"github.com/netbirdio/netbird/management/server/util"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/netbirdio/netbird/management/server/http/api"
|
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
|
"github.com/netbirdio/netbird/management/server/status"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
)
|
|
|
|
// Fixture identifiers used by the PAT handler tests. The mocked account
// manager in this file only accepts the "existing*" values and answers with a
// NotFound status error for anything else.
const (
	// Account identifiers.
	existingAccountID = "existingAccountID"
	notFoundAccountID = "notFoundAccountID"

	// User identifiers.
	existingUserID = "existingUserID"
	notFoundUserID = "notFoundUserID"

	// Personal access token identifiers.
	existingTokenID = "existingTokenID"
	notFoundTokenID = "notFoundTokenID"

	// testDomain is the domain set on the fixture account and on the
	// authorization claims injected into every request.
	testDomain = "hotmail.com"
)
|
|
|
|
var testAccount = &types.Account{
|
|
Id: existingAccountID,
|
|
Domain: testDomain,
|
|
Users: map[string]*types.User{
|
|
existingUserID: {
|
|
Id: existingUserID,
|
|
PATs: map[string]*types.PersonalAccessToken{
|
|
existingTokenID: {
|
|
ID: existingTokenID,
|
|
Name: "My first token",
|
|
HashedToken: "someHash",
|
|
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
|
|
CreatedBy: existingUserID,
|
|
CreatedAt: time.Now().UTC(),
|
|
LastUsed: util.ToPtr(time.Now().UTC()),
|
|
},
|
|
"token2": {
|
|
ID: "token2",
|
|
Name: "My second token",
|
|
HashedToken: "someOtherHash",
|
|
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
|
|
CreatedBy: existingUserID,
|
|
CreatedAt: time.Now().UTC(),
|
|
LastUsed: util.ToPtr(time.Now().UTC()),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
func initPATTestData() *patHandler {
|
|
return &patHandler{
|
|
accountManager: &mock_server.MockAccountManager{
|
|
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
|
if accountID != existingAccountID {
|
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
|
}
|
|
if targetUserID != existingUserID {
|
|
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
|
}
|
|
return &types.PersonalAccessTokenGenerated{
|
|
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
|
|
PersonalAccessToken: types.PersonalAccessToken{},
|
|
}, nil
|
|
},
|
|
|
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
|
return claims.AccountId, claims.UserId, nil
|
|
},
|
|
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
|
if accountID != existingAccountID {
|
|
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
|
}
|
|
if targetUserID != existingUserID {
|
|
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
|
}
|
|
if tokenID != existingTokenID {
|
|
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
|
}
|
|
return nil
|
|
},
|
|
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
|
|
if accountID != existingAccountID {
|
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
|
}
|
|
if targetUserID != existingUserID {
|
|
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
|
}
|
|
if tokenID != existingTokenID {
|
|
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
|
}
|
|
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
|
},
|
|
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
|
|
if accountID != existingAccountID {
|
|
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
|
}
|
|
if targetUserID != existingUserID {
|
|
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
|
}
|
|
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
|
},
|
|
},
|
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
|
return jwtclaims.AuthorizationClaims{
|
|
UserId: existingUserID,
|
|
Domain: testDomain,
|
|
AccountId: existingAccountID,
|
|
}
|
|
}),
|
|
),
|
|
}
|
|
}
|
|
|
|
func TestTokenHandlers(t *testing.T) {
|
|
tt := []struct {
|
|
name string
|
|
expectedStatus int
|
|
expectedBody bool
|
|
requestType string
|
|
requestPath string
|
|
requestBody io.Reader
|
|
}{
|
|
{
|
|
name: "Get All Tokens",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: true,
|
|
},
|
|
{
|
|
name: "Get Existing Token",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: true,
|
|
},
|
|
{
|
|
name: "Get Not Existing Token",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
|
expectedStatus: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "Delete Existing Token",
|
|
requestType: http.MethodDelete,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Delete Not Existing Token",
|
|
requestType: http.MethodDelete,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
|
expectedStatus: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "POST OK",
|
|
requestType: http.MethodPost,
|
|
requestPath: "/api/users/" + existingUserID + "/tokens",
|
|
requestBody: bytes.NewBuffer(
|
|
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: true,
|
|
},
|
|
}
|
|
|
|
p := initPATTestData()
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
|
|
|
router := mux.NewRouter()
|
|
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
|
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
|
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
|
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
res := recorder.Result()
|
|
defer res.Body.Close()
|
|
|
|
content, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
t.Fatalf("I don't know what I expected; %v", err)
|
|
}
|
|
|
|
if status := recorder.Code; status != tc.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
|
status, tc.expectedStatus, string(content))
|
|
return
|
|
}
|
|
|
|
if !tc.expectedBody {
|
|
return
|
|
}
|
|
|
|
switch tc.name {
|
|
case "POST OK":
|
|
got := &api.PersonalAccessTokenGenerated{}
|
|
if err = json.Unmarshal(content, &got); err != nil {
|
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
|
}
|
|
assert.NotEmpty(t, got.PlainToken)
|
|
assert.Equal(t, types.PATLength, len(got.PlainToken))
|
|
case "Get All Tokens":
|
|
expectedTokens := []api.PersonalAccessToken{
|
|
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
|
|
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
|
|
}
|
|
|
|
var got []api.PersonalAccessToken
|
|
if err = json.Unmarshal(content, &got); err != nil {
|
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
|
}
|
|
assert.True(t, cmp.Equal(got, expectedTokens))
|
|
case "Get Existing Token":
|
|
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
|
|
got := &api.PersonalAccessToken{}
|
|
if err = json.Unmarshal(content, &got); err != nil {
|
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
|
}
|
|
|
|
assert.True(t, cmp.Equal(*got, expectedToken))
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken {
|
|
return api.PersonalAccessToken{
|
|
Id: serverToken.ID,
|
|
Name: serverToken.Name,
|
|
CreatedAt: serverToken.CreatedAt,
|
|
LastUsed: serverToken.LastUsed,
|
|
CreatedBy: serverToken.CreatedBy,
|
|
ExpirationDate: serverToken.GetExpirationDate(),
|
|
}
|
|
}
|