mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 17:43:38 +01:00
27ed88f918
Instead of GetAllConnectedPeers, which needs to traverse the whole connections map in order to find one channel there.
248 lines
6.7 KiB
Go
248 lines
6.7 KiB
Go
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
|
|
"github.com/netbirdio/netbird/management/server/http/api"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
|
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
|
|
"github.com/magiconair/properties/assert"
|
|
|
|
"github.com/netbirdio/netbird/management/server"
|
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
|
)
|
|
|
|
// Peer IDs shared by the fixtures in this test file.
const (
	// testPeerID is the ID of the primary test peer.
	testPeerID = "test_peer"

	// noUpdateChannelTestPeerID is the ID at which the mocked
	// HasConnectedChannelFunc stops registering peers as having a channel.
	noUpdateChannelTestPeerID = "no-update-channel"
)
|
|
|
|
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|
return &PeersHandler{
|
|
accountManager: &mock_server.MockAccountManager{
|
|
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
|
var p *nbpeer.Peer
|
|
for _, peer := range peers {
|
|
if update.ID == peer.ID {
|
|
p = peer.Copy()
|
|
break
|
|
}
|
|
}
|
|
p.SSHEnabled = update.SSHEnabled
|
|
p.LoginExpirationEnabled = update.LoginExpirationEnabled
|
|
p.Name = update.Name
|
|
return p, nil
|
|
},
|
|
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
|
var p *nbpeer.Peer
|
|
for _, peer := range peers {
|
|
if peerID == peer.ID {
|
|
p = peer.Copy()
|
|
break
|
|
}
|
|
}
|
|
return p, nil
|
|
},
|
|
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
|
|
return peers, nil
|
|
},
|
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
|
user := server.NewAdminUser("test_user")
|
|
return &server.Account{
|
|
Id: claims.AccountId,
|
|
Domain: "hotmail.com",
|
|
Peers: map[string]*nbpeer.Peer{
|
|
peers[0].ID: peers[0],
|
|
peers[1].ID: peers[1],
|
|
},
|
|
Users: map[string]*server.User{
|
|
"test_user": user,
|
|
},
|
|
Settings: &server.Settings{
|
|
PeerLoginExpirationEnabled: true,
|
|
PeerLoginExpiration: time.Hour,
|
|
},
|
|
Network: &server.Network{
|
|
Identifier: "ciclqisab2ss43jdn8q0",
|
|
Net: net.IPNet{
|
|
IP: net.ParseIP("100.67.0.0"),
|
|
Mask: net.IPv4Mask(255, 255, 0, 0),
|
|
},
|
|
Serial: 51,
|
|
},
|
|
}, user, nil
|
|
},
|
|
HasConnectedChannelFunc: func(peerID string) bool {
|
|
statuses := make(map[string]struct{})
|
|
for _, peer := range peers {
|
|
if peer.ID == noUpdateChannelTestPeerID {
|
|
break
|
|
}
|
|
statuses[peer.ID] = struct{}{}
|
|
}
|
|
_, ok := statuses[peerID]
|
|
return ok
|
|
},
|
|
},
|
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
|
return jwtclaims.AuthorizationClaims{
|
|
UserId: "test_user",
|
|
Domain: "hotmail.com",
|
|
AccountId: "test_id",
|
|
}
|
|
}),
|
|
),
|
|
}
|
|
}
|
|
|
|
// Tests the GetAllPeers endpoint reachable in the route /api/peers
|
|
// Use the metadata generated by initTestMetaData() to check for values
|
|
func TestGetPeers(t *testing.T) {
|
|
|
|
peer := &nbpeer.Peer{
|
|
ID: testPeerID,
|
|
Key: "key",
|
|
SetupKey: "setupkey",
|
|
IP: net.ParseIP("100.64.0.1"),
|
|
Status: &nbpeer.PeerStatus{Connected: true},
|
|
Name: "PeerName",
|
|
LoginExpirationEnabled: false,
|
|
Meta: nbpeer.PeerSystemMeta{
|
|
Hostname: "hostname",
|
|
GoOS: "GoOS",
|
|
Kernel: "kernel",
|
|
Core: "core",
|
|
Platform: "platform",
|
|
OS: "OS",
|
|
WtVersion: "development",
|
|
},
|
|
}
|
|
|
|
peer1 := peer.Copy()
|
|
peer1.ID = noUpdateChannelTestPeerID
|
|
|
|
expectedUpdatedPeer := peer.Copy()
|
|
expectedUpdatedPeer.LoginExpirationEnabled = true
|
|
expectedUpdatedPeer.SSHEnabled = true
|
|
expectedUpdatedPeer.Name = "New Name"
|
|
|
|
expectedPeer1 := peer1.Copy()
|
|
expectedPeer1.Status.Connected = false
|
|
|
|
tt := []struct {
|
|
name string
|
|
expectedStatus int
|
|
requestType string
|
|
requestPath string
|
|
requestBody io.Reader
|
|
expectedArray bool
|
|
expectedPeer *nbpeer.Peer
|
|
}{
|
|
{
|
|
name: "GetPeersMetaData",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/peers/",
|
|
expectedStatus: http.StatusOK,
|
|
expectedArray: true,
|
|
expectedPeer: peer,
|
|
},
|
|
{
|
|
name: "GetPeer with update channel",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/peers/" + testPeerID,
|
|
expectedStatus: http.StatusOK,
|
|
expectedArray: false,
|
|
expectedPeer: peer,
|
|
},
|
|
{
|
|
name: "GetPeer with no update channel",
|
|
requestType: http.MethodGet,
|
|
requestPath: "/api/peers/" + peer1.ID,
|
|
expectedStatus: http.StatusOK,
|
|
expectedArray: false,
|
|
expectedPeer: expectedPeer1,
|
|
},
|
|
{
|
|
name: "PutPeer",
|
|
requestType: http.MethodPut,
|
|
requestPath: "/api/peers/" + testPeerID,
|
|
expectedStatus: http.StatusOK,
|
|
expectedArray: false,
|
|
requestBody: bytes.NewBufferString("{\"login_expiration_enabled\":true,\"name\":\"New Name\",\"ssh_enabled\":true}"),
|
|
expectedPeer: expectedUpdatedPeer,
|
|
},
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
p := initTestMetaData(peer, peer1)
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
recorder := httptest.NewRecorder()
|
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
|
|
|
router := mux.NewRouter()
|
|
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
|
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
|
|
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
res := recorder.Result()
|
|
defer res.Body.Close()
|
|
|
|
if status := rr.Code; status != tc.expectedStatus {
|
|
t.Fatalf("handler returned wrong status code: got %v want %v",
|
|
status, http.StatusOK)
|
|
}
|
|
|
|
content, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
t.Fatalf("I don't know what I expected; %v", err)
|
|
}
|
|
|
|
var got *api.Peer
|
|
if tc.expectedArray {
|
|
respBody := []*api.Peer{}
|
|
err = json.Unmarshal(content, &respBody)
|
|
if err != nil {
|
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
|
}
|
|
|
|
// hardcode this check for now as we only have two peers in this suite
|
|
assert.Equal(t, len(respBody), 2)
|
|
assert.Equal(t, respBody[1].Connected, false)
|
|
|
|
got = respBody[0]
|
|
} else {
|
|
got = &api.Peer{}
|
|
err = json.Unmarshal(content, got)
|
|
if err != nil {
|
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
|
}
|
|
}
|
|
|
|
t.Log(got)
|
|
|
|
assert.Equal(t, got.Name, tc.expectedPeer.Name)
|
|
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
|
|
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
|
|
assert.Equal(t, got.Os, "OS core")
|
|
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
|
|
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
|
|
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
|
|
})
|
|
}
|
|
}
|