diff --git a/management/client/client_test.go b/management/client/client_test.go index a93c253bc..3f95347aa 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -16,6 +16,8 @@ import ( "github.com/wiretrustee/wiretrustee/management/proto" mgmtProto "github.com/wiretrustee/wiretrustee/management/proto" mgmt "github.com/wiretrustee/wiretrustee/management/server" + "github.com/wiretrustee/wiretrustee/management/server/mock_server" + "github.com/wiretrustee/wiretrustee/util" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -25,7 +27,7 @@ import ( var tested *GrpcClient var serverAddr string -var mgmtMockServer *mgmt.ManagementServiceServerMock +var mgmtMockServer *mock_server.ManagementServiceServerMock var serverKey wgtypes.Key const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" @@ -100,7 +102,7 @@ func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } - mgmtMockServer = &mgmt.ManagementServiceServerMock{ + mgmtMockServer = &mock_server.ManagementServiceServerMock{ GetServerKeyFunc: func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error) { response := &proto.ServerKeyResponse{ Key: serverKey.PublicKey().String(), diff --git a/management/server/http/handler/peers.go b/management/server/http/handler/peers.go index 3de4b894a..bdedd723d 100644 --- a/management/server/http/handler/peers.go +++ b/management/server/http/handler/peers.go @@ -15,6 +15,7 @@ import ( type Peers struct { accountManager server.AccountManager authAudience string + jwtExtractor JWTClaimsExtractor } //PeerResponse is a response sent to the client @@ -36,6 +37,7 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers return &Peers{ accountManager: accountManager, authAudience: authAudience, + jwtExtractor: *NewJWTClaimsExtractor(nil), } } @@ -55,6 +57,7 @@ func (h *Peers) updatePeer(accountId string, peer *server.Peer, w http.ResponseW } writeJSONObject(w, toPeerResponse(peer)) } + func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) { _, err := h.accountManager.DeletePeer(accountId, peer.Key) if err != nil { @@ -66,7 +69,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW } func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) { - jwtClaims := extractClaimsFromRequestContext(r, h.authAudience) + jwtClaims := h.jwtExtractor.extractClaimsFromRequestContext(r, h.authAudience) account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { diff --git a/management/server/http/handler/peers_test.go b/management/server/http/handler/peers_test.go new file mode 100644 index 000000000..8a67a38d8 --- /dev/null +++ b/management/server/http/handler/peers_test.go @@ -0,0 +1,107 @@ +package handler + +import ( + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/magiconair/properties/assert" + "github.com/wiretrustee/wiretrustee/management/server" + "github.com/wiretrustee/wiretrustee/management/server/mock_server" +) + +func initTestMetaData(peer ...*server.Peer) *Peers { + return &Peers{ + accountManager: &mock_server.MockAccountManager{ + GetAccountByUserOrAccountIdFunc: func(userId, accountId, domain string) (*server.Account, error) { + return &server.Account{ + Id: accountId, + Domain: "hotmail.com", + Peers: map[string]*server.Peer{ + "test_peer": peer[0], + }, + }, nil + }, + }, + authAudience: "", + jwtExtractor: JWTClaimsExtractor{ + extractClaimsFromRequestContext: func(r *http.Request, authAudiance string) JWTClaims { + return JWTClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + } + }, + }, + } +} + +// Tests the GetPeers endpoint reachable in the route /api/peers +// Use the metadata generated by initTestMetaData() to check for values +func TestGetPeers(t *testing.T) { + var tt = []struct { + name string + expectedStatus int + requestType string + requestPath string + requestBody io.Reader + }{ + {name: "GetPeersMetaData", requestType: http.MethodGet, requestPath: "/api/peers/", expectedStatus: http.StatusOK}, + } + + rr := httptest.NewRecorder() + peer := &server.Peer{ + Key: "key", + SetupKey: "setupkey", + IP: net.ParseIP("100.64.0.1"), + Status: &server.PeerStatus{}, + Name: "PeerName", + Meta: server.PeerSystemMeta{ + Hostname: "hostname", + GoOS: "GoOS", + Kernel: "kernel", + Core: "core", + Platform: "platform", + OS: "OS", + WtVersion: "development", + }, + } + + p := initTestMetaData(peer) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + p.GetPeers(rr, req) + + res := rr.Result() + defer res.Body.Close() + + if status := rr.Code; status != tc.expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + respBody := []*PeerResponse{} + err = json.Unmarshal(content, &respBody) + if err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + got := respBody[0] + assert.Equal(t, got.Name, peer.Name) + assert.Equal(t, got.Version, peer.Meta.WtVersion) + assert.Equal(t, got.IP, peer.IP.String()) + assert.Equal(t, got.OS, "OS core") + }) + } +} diff --git a/management/server/http/handler/setupkeys.go b/management/server/http/handler/setupkeys.go index 4bbf6ba2a..48e24e072 100644 --- a/management/server/http/handler/setupkeys.go +++ b/management/server/http/handler/setupkeys.go @@ -122,7 +122,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R } func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) { - jwtClaims := extractClaimsFromRequestContext(r, h.authAudience) + extractor := NewJWTClaimsExtractor(nil) + jwtClaims := extractor.extractClaimsFromRequestContext(r, h.authAudience) account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { diff --git a/management/server/http/handler/util.go b/management/server/http/handler/util.go index af0ad6b28..1c7b63b53 100644 --- a/management/server/http/handler/util.go +++ b/management/server/http/handler/util.go @@ -3,9 +3,10 @@ package handler import ( "encoding/json" "errors" - "github.com/golang-jwt/jwt" "net/http" "time" + + "github.com/golang-jwt/jwt" ) // JWTClaims stores information from JWTs @@ -15,6 +16,25 @@ type JWTClaims struct { Domain string } +type extractJWTClaims func(r *http.Request, authAudiance string) JWTClaims + +type JWTClaimsExtractor struct { + extractClaimsFromRequestContext extractJWTClaims +} + +// NewJWTClaimsExtractor returns an extractor, and if provided with a function with extractJWTClaims signature, +// then it will use that logic. Uses extractClaimsFromRequestContext by default +func NewJWTClaimsExtractor(e extractJWTClaims) *JWTClaimsExtractor { + var extractFunc extractJWTClaims + if extractFunc = e; extractFunc == nil { + extractFunc = extractClaimsFromRequestContext + } + + return &JWTClaimsExtractor{ + extractClaimsFromRequestContext: extractFunc, + } +} + // extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims { token := r.Context().Value("user").(*jwt.Token) @@ -34,7 +54,7 @@ func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTCl //writeJSONObject simply writes object to the HTTP reponse in JSON format func writeJSONObject(w http.ResponseWriter, obj interface{}) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json; charset=UTF-8") err := json.NewEncoder(w).Encode(obj) if err != nil { diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index dc7ffb22e..a4b4b5dbe 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -3,13 +3,14 @@ package idp import ( "encoding/json" "fmt" - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/assert" "io/ioutil" "net/http" "strings" "testing" "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" ) type mockHTTPClient struct { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go new file mode 100644 index 000000000..d3907e72d --- /dev/null +++ b/management/server/mock_server/account_mock.go @@ -0,0 +1,139 @@ +package mock_server + +import ( + "github.com/wiretrustee/wiretrustee/management/server" + "github.com/wiretrustee/wiretrustee/util" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type MockAccountManager struct { + GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) + GetAccountByUserFunc func(userId string) (*server.Account, error) + AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) + RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error) + RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error) + GetAccountByIdFunc func(accountId string) (*server.Account, error) + GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) + AccountExistsFunc func(accountId string) (*bool, error) + AddAccountFunc func(accountId, userId, domain string) (*server.Account, error) + GetPeerFunc func(peerKey string) (*server.Peer, error) + MarkPeerConnectedFunc func(peerKey string, connected bool) error + RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) + DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) + GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) + GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error) +} + +func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) { + if am.GetOrCreateAccountByUserFunc != nil { + return am.GetOrCreateAccountByUserFunc(userId, domain) + } + return nil, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByUser not implemented") +} + +func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) { + if am.GetAccountByUserFunc != nil { + return am.GetAccountByUserFunc(userId) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser not implemented") +} + +func (am *MockAccountManager) AddSetupKey(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) { + if am.AddSetupKeyFunc != nil { + return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey not implemented") +} + +func (am *MockAccountManager) RevokeSetupKey(accountId string, keyId string) (*server.SetupKey, error) { + if am.RevokeSetupKeyFunc != nil { + return am.RevokeSetupKeyFunc(accountId, keyId) + } + return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey not implemented") +} + +func (am *MockAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*server.SetupKey, error) { + if am.RenameSetupKeyFunc != nil { + return am.RenameSetupKeyFunc(accountId, keyId, newName) + } + return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey not implemented") +} + +func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) { + if am.GetAccountByIdFunc != nil { + return am.GetAccountByIdFunc(accountId) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountById not implemented") +} + +func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*server.Account, error) { + if am.GetAccountByUserOrAccountIdFunc != nil { + return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented") +} + +func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { + if am.AccountExistsFunc != nil { + return am.AccountExistsFunc(accountId) + } + return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented") +} + +func (am *MockAccountManager) AddAccount(accountId, userId, domain string) (*server.Account, error) { + if am.AddAccountFunc != nil { + return am.AddAccountFunc(accountId, userId, domain) + } + return nil, status.Errorf(codes.Unimplemented, "method AddAccount not implemented") +} + +func (am *MockAccountManager) GetPeer(peerKey string) (*server.Peer, error) { + if am.GetPeerFunc != nil { + return am.GetPeerFunc(peerKey) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeer not implemented") +} + +func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error { + if am.MarkPeerConnectedFunc != nil { + return am.MarkPeerConnectedFunc(peerKey, connected) + } + return status.Errorf(codes.Unimplemented, "method MarkPeerConnected not implemented") +} + +func (am *MockAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*server.Peer, error) { + if am.RenamePeerFunc != nil { + return am.RenamePeerFunc(accountId, peerKey, newName) + } + return nil, status.Errorf(codes.Unimplemented, "method RenamePeer not implemented") +} + +func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) { + if am.DeletePeerFunc != nil { + return am.DeletePeerFunc(accountId, peerKey) + } + return nil, status.Errorf(codes.Unimplemented, "method DeletePeer not implemented") +} + +func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) { + if am.GetPeerByIPFunc != nil { + return am.GetPeerByIPFunc(accountId, peerIP) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP not implemented") +} + +func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { + if am.GetNetworkMapFunc != nil { + return am.GetNetworkMapFunc(peerKey) + } + return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap not implemented") +} + +func (am *MockAccountManager) AddPeer(setupKey string, peer *server.Peer) (*server.Peer, error) { + if am.AddPeerFunc != nil { + return am.AddPeerFunc(setupKey, peer) + } + return nil, status.Errorf(codes.Unimplemented, "method AddPeer not implemented") +} diff --git a/management/server/management_server_mock.go b/management/server/mock_server/management_server_mock.go similarity index 98% rename from management/server/management_server_mock.go rename to management/server/mock_server/management_server_mock.go index 8ec077894..181aed188 100644 --- a/management/server/management_server_mock.go +++ b/management/server/mock_server/management_server_mock.go @@ -1,4 +1,4 @@ -package server +package mock_server import ( "context"