netbird/management/server/http/peers_handler_test.go

247 lines
6.6 KiB
Go
Raw Normal View History

package http
2022-02-22 18:18:05 +01:00
import (
"bytes"
2022-02-22 18:18:05 +01:00
"encoding/json"
"fmt"
2022-02-22 18:18:05 +01:00
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
2022-02-22 18:18:05 +01:00
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
2022-02-22 18:18:05 +01:00
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
2022-02-22 18:18:05 +01:00
)
2023-02-07 20:11:08 +01:00
// Peer IDs used by the fixtures in this file. noUpdateChannelTestPeerID is
// excluded from the connected set reported by the mock in initTestMetaData,
// simulating a peer that has no update channel.
const (
	testPeerID                = "test_peer"
	noUpdateChannelTestPeerID = "no-update-channel"
)
2023-02-07 20:11:08 +01:00
func initTestMetaData(peers ...*server.Peer) *PeersHandler {
return &PeersHandler{
2022-02-22 18:18:05 +01:00
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if update.ID == peer.ID {
p = peer.Copy()
break
}
}
p.SSHEnabled = update.SSHEnabled
p.LoginExpirationEnabled = update.LoginExpirationEnabled
p.Name = update.Name
return p, nil
},
2023-02-07 20:11:08 +01:00
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if peerID == peer.ID {
p = peer.Copy()
break
}
}
return p, nil
2023-02-07 20:11:08 +01:00
},
2022-11-05 10:24:50 +01:00
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
2022-02-22 18:18:05 +01:00
return &server.Account{
Id: claims.AccountId,
2022-02-22 18:18:05 +01:00
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
2023-02-07 20:11:08 +01:00
peers[0].ID: peers[0],
2022-11-05 10:24:50 +01:00
},
Users: map[string]*server.User{
"test_user": user,
2022-02-22 18:18:05 +01:00
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}, user, nil
2022-02-22 18:18:05 +01:00
},
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
statuses := make(map[string]struct{})
for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID {
break
}
statuses[peer.ID] = struct{}{}
}
return statuses, nil
},
2022-02-22 18:18:05 +01:00
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
2022-02-22 18:18:05 +01:00
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
2022-02-22 18:18:05 +01:00
}
}
// Tests the GetAllPeers endpoint reachable in the route /api/peers
2022-02-22 18:18:05 +01:00
// Use the metadata generated by initTestMetaData() to check for values
func TestGetPeers(t *testing.T) {
peer := &server.Peer{
ID: testPeerID,
Key: "key",
SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"),
Status: &server.PeerStatus{Connected: true},
Name: "PeerName",
LoginExpirationEnabled: false,
Meta: server.PeerSystemMeta{
Hostname: "hostname",
GoOS: "GoOS",
Kernel: "kernel",
Core: "core",
Platform: "platform",
OS: "OS",
WtVersion: "development",
},
}
peer1 := peer.Copy()
peer1.ID = noUpdateChannelTestPeerID
expectedUpdatedPeer := peer.Copy()
expectedUpdatedPeer.LoginExpirationEnabled = true
expectedUpdatedPeer.SSHEnabled = true
expectedUpdatedPeer.Name = "New Name"
expectedPeer1 := peer1.Copy()
expectedPeer1.Status.Connected = false
tt := []struct {
2022-02-22 18:18:05 +01:00
name string
expectedStatus int
requestType string
requestPath string
requestBody io.Reader
2023-02-07 20:11:08 +01:00
expectedArray bool
expectedPeer *server.Peer
2022-02-22 18:18:05 +01:00
}{
{
name: "GetPeersMetaData",
requestType: http.MethodGet,
requestPath: "/api/peers/",
expectedStatus: http.StatusOK,
2023-02-07 20:11:08 +01:00
expectedArray: true,
expectedPeer: peer,
2023-02-07 20:11:08 +01:00
},
{
name: "GetPeer with update channel",
2023-02-07 20:11:08 +01:00
requestType: http.MethodGet,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
requestBody: bytes.NewBufferString("{\"login_expiration_enabled\":true,\"name\":\"New Name\",\"ssh_enabled\":true}"),
expectedPeer: expectedUpdatedPeer,
},
2022-02-22 18:18:05 +01:00
}
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
2022-02-22 18:18:05 +01:00
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
2023-02-07 20:11:08 +01:00
recorder := httptest.NewRecorder()
2022-02-22 18:18:05 +01:00
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
2023-02-07 20:11:08 +01:00
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
2023-02-07 20:11:08 +01:00
router.ServeHTTP(recorder, req)
2022-02-22 18:18:05 +01:00
2023-02-07 20:11:08 +01:00
res := recorder.Result()
2022-02-22 18:18:05 +01:00
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
2023-02-07 20:11:08 +01:00
var got *api.Peer
if tc.expectedArray {
respBody := []*api.Peer{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
2023-02-07 20:11:08 +01:00
got = respBody[0]
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
2022-02-22 18:18:05 +01:00
}
fmt.Println(got)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
2022-02-22 18:18:05 +01:00
})
}
}