netbird/management/server/http/peers_handler_test.go
Yury Gargay 659110f0d5
Rework peer connection status based on the update channel existence (#1213)
With this change, we don't need to update all peers on startup. We will
check the existence of an update channel when returning a list or single peer on API.
Then after restarting of server consumers of API will see peer not
connected status till the creation of an updated channel which indicates
peer successful connection.
2023-10-11 18:11:45 +02:00

238 lines
6.4 KiB
Go

package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*server.Peer) *PeersHandler {
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if update.ID == peer.ID {
p = peer.Copy()
break
}
}
p.SSHEnabled = update.SSHEnabled
p.LoginExpirationEnabled = update.LoginExpirationEnabled
p.Name = update.Name
return p, nil
},
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if peerID == peer.ID {
p = peer.Copy()
break
}
}
return p, nil
},
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
peers[0].ID: peers[0],
},
Users: map[string]*server.User{
"test_user": user,
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
}, user, nil
},
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
statuses := make(map[string]struct{})
for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID {
break
}
statuses[peer.ID] = struct{}{}
}
return statuses, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
// Tests the GetAllPeers endpoint reachable in the route /api/peers
// Use the metadata generated by initTestMetaData() to check for values
func TestGetPeers(t *testing.T) {
peer := &server.Peer{
ID: testPeerID,
Key: "key",
SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"),
Status: &server.PeerStatus{Connected: true},
Name: "PeerName",
LoginExpirationEnabled: false,
Meta: server.PeerSystemMeta{
Hostname: "hostname",
GoOS: "GoOS",
Kernel: "kernel",
Core: "core",
Platform: "platform",
OS: "OS",
WtVersion: "development",
},
}
peer1 := peer.Copy()
peer1.ID = noUpdateChannelTestPeerID
expectedUpdatedPeer := peer.Copy()
expectedUpdatedPeer.LoginExpirationEnabled = true
expectedUpdatedPeer.SSHEnabled = true
expectedUpdatedPeer.Name = "New Name"
expectedPeer1 := peer1.Copy()
expectedPeer1.Status.Connected = false
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestBody io.Reader
expectedArray bool
expectedPeer *server.Peer
}{
{
name: "GetPeersMetaData",
requestType: http.MethodGet,
requestPath: "/api/peers/",
expectedStatus: http.StatusOK,
expectedArray: true,
expectedPeer: peer,
},
{
name: "GetPeer with update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
requestBody: bytes.NewBufferString("{\"login_expiration_enabled\":true,\"name\":\"New Name\",\"ssh_enabled\":true}"),
expectedPeer: expectedUpdatedPeer,
},
}
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
var got *api.Peer
if tc.expectedArray {
respBody := []*api.Peer{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0]
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
}
fmt.Println(got)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
})
}
}