mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 17:31:39 +02:00
Test mgmt http handler (#240)
This commit is contained in:
parent
41c6af6b6f
commit
5f5cbf7e20
@ -16,6 +16,8 @@ import (
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||
mgmt "github.com/wiretrustee/wiretrustee/management/server"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/mock_server"
|
||||
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
@ -25,7 +27,7 @@ import (
|
||||
|
||||
var tested *GrpcClient
|
||||
var serverAddr string
|
||||
var mgmtMockServer *mgmt.ManagementServiceServerMock
|
||||
var mgmtMockServer *mock_server.ManagementServiceServerMock
|
||||
var serverKey wgtypes.Key
|
||||
|
||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
@ -100,7 +102,7 @@ func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mgmtMockServer = &mgmt.ManagementServiceServerMock{
|
||||
mgmtMockServer = &mock_server.ManagementServiceServerMock{
|
||||
GetServerKeyFunc: func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
response := &proto.ServerKeyResponse{
|
||||
Key: serverKey.PublicKey().String(),
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
type Peers struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
jwtExtractor JWTClaimsExtractor
|
||||
}
|
||||
|
||||
//PeerResponse is a response sent to the client
|
||||
@ -36,6 +37,7 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
|
||||
return &Peers{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *NewJWTClaimsExtractor(nil),
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,6 +57,7 @@ func (h *Peers) updatePeer(accountId string, peer *server.Peer, w http.ResponseW
|
||||
}
|
||||
writeJSONObject(w, toPeerResponse(peer))
|
||||
}
|
||||
|
||||
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||
_, err := h.accountManager.DeletePeer(accountId, peer.Key)
|
||||
if err != nil {
|
||||
@ -66,7 +69,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
|
||||
}
|
||||
|
||||
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||
jwtClaims := h.jwtExtractor.extractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
if err != nil {
|
||||
|
107
management/server/http/handler/peers_test.go
Normal file
107
management/server/http/handler/peers_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/wiretrustee/wiretrustee/management/server"
|
||||
"github.com/wiretrustee/wiretrustee/management/server/mock_server"
|
||||
)
|
||||
|
||||
func initTestMetaData(peer ...*server.Peer) *Peers {
|
||||
return &Peers{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountByUserOrAccountIdFunc: func(userId, accountId, domain string) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: accountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*server.Peer{
|
||||
"test_peer": peer[0],
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: JWTClaimsExtractor{
|
||||
extractClaimsFromRequestContext: func(r *http.Request, authAudiance string) JWTClaims {
|
||||
return JWTClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Tests the GetPeers endpoint reachable in the route /api/peers
|
||||
// Use the metadata generated by initTestMetaData() to check for values
|
||||
func TestGetPeers(t *testing.T) {
|
||||
var tt = []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{name: "GetPeersMetaData", requestType: http.MethodGet, requestPath: "/api/peers/", expectedStatus: http.StatusOK},
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
peer := &server.Peer{
|
||||
Key: "key",
|
||||
SetupKey: "setupkey",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &server.PeerStatus{},
|
||||
Name: "PeerName",
|
||||
Meta: server.PeerSystemMeta{
|
||||
Hostname: "hostname",
|
||||
GoOS: "GoOS",
|
||||
Kernel: "kernel",
|
||||
Core: "core",
|
||||
Platform: "platform",
|
||||
OS: "OS",
|
||||
WtVersion: "development",
|
||||
},
|
||||
}
|
||||
|
||||
p := initTestMetaData(peer)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
p.GetPeers(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
respBody := []*PeerResponse{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
got := respBody[0]
|
||||
assert.Equal(t, got.Name, peer.Name)
|
||||
assert.Equal(t, got.Version, peer.Meta.WtVersion)
|
||||
assert.Equal(t, got.IP, peer.IP.String())
|
||||
assert.Equal(t, got.OS, "OS core")
|
||||
})
|
||||
}
|
||||
}
|
@ -122,7 +122,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
||||
}
|
||||
|
||||
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := extractClaimsFromRequestContext(r, h.authAudience)
|
||||
extractor := NewJWTClaimsExtractor(nil)
|
||||
jwtClaims := extractor.extractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||
if err != nil {
|
||||
|
@ -3,9 +3,10 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// JWTClaims stores information from JWTs
|
||||
@ -15,6 +16,25 @@ type JWTClaims struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
type extractJWTClaims func(r *http.Request, authAudiance string) JWTClaims
|
||||
|
||||
type JWTClaimsExtractor struct {
|
||||
extractClaimsFromRequestContext extractJWTClaims
|
||||
}
|
||||
|
||||
// NewJWTClaimsExtractor returns an extractor, and if provided with a function with extractJWTClaims signature,
|
||||
// then it will use that logic. Uses extractClaimsFromRequestContext by default
|
||||
func NewJWTClaimsExtractor(e extractJWTClaims) *JWTClaimsExtractor {
|
||||
var extractFunc extractJWTClaims
|
||||
if extractFunc = e; extractFunc == nil {
|
||||
extractFunc = extractClaimsFromRequestContext
|
||||
}
|
||||
|
||||
return &JWTClaimsExtractor{
|
||||
extractClaimsFromRequestContext: extractFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
|
||||
token := r.Context().Value("user").(*jwt.Token)
|
||||
@ -34,7 +54,7 @@ func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTCl
|
||||
|
||||
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
|
@ -3,13 +3,14 @@ package idp
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
|
139
management/server/mock_server/account_mock.go
Normal file
139
management/server/mock_server/account_mock.go
Normal file
@ -0,0 +1,139 @@
|
||||
package mock_server
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/management/server"
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error)
|
||||
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
|
||||
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
AddAccountFunc func(accountId, userId, domain string) (*server.Account, error)
|
||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(userId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByUser not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) {
|
||||
if am.GetAccountByUserFunc != nil {
|
||||
return am.GetAccountByUserFunc(userId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AddSetupKey(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) {
|
||||
if am.AddSetupKeyFunc != nil {
|
||||
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RevokeSetupKey(accountId string, keyId string) (*server.SetupKey, error) {
|
||||
if am.RevokeSetupKeyFunc != nil {
|
||||
return am.RevokeSetupKeyFunc(accountId, keyId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*server.SetupKey, error) {
|
||||
if am.RenameSetupKeyFunc != nil {
|
||||
return am.RenameSetupKeyFunc(accountId, keyId, newName)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) {
|
||||
if am.GetAccountByIdFunc != nil {
|
||||
return am.GetAccountByIdFunc(accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*server.Account, error) {
|
||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
||||
return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
if am.AccountExistsFunc != nil {
|
||||
return am.AccountExistsFunc(accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AddAccount(accountId, userId, domain string) (*server.Account, error) {
|
||||
if am.AddAccountFunc != nil {
|
||||
return am.AddAccountFunc(accountId, userId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AddAccount not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetPeer(peerKey string) (*server.Peer, error) {
|
||||
if am.GetPeerFunc != nil {
|
||||
return am.GetPeerFunc(peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeer not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(peerKey, connected)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*server.Peer, error) {
|
||||
if am.RenamePeerFunc != nil {
|
||||
return am.RenamePeerFunc(accountId, peerKey, newName)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
|
||||
if am.DeletePeerFunc != nil {
|
||||
return am.DeletePeerFunc(accountId, peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
|
||||
if am.GetPeerByIPFunc != nil {
|
||||
return am.GetPeerByIPFunc(accountId, peerIP)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) {
|
||||
if am.GetNetworkMapFunc != nil {
|
||||
return am.GetNetworkMapFunc(peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AddPeer(setupKey string, peer *server.Peer) (*server.Peer, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(setupKey, peer)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AddPeer not implemented")
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package server
|
||||
package mock_server
|
||||
|
||||
import (
|
||||
"context"
|
Loading…
x
Reference in New Issue
Block a user