bcmmbaga e326cc7412
Merge branch 'main' into routes-get-account-refactoring
# Conflicts:
#	.github/workflows/golang-test-linux.yml
#	go.mod
#	go.sum
#	management/server/account.go
#	management/server/account_test.go
#	management/server/http/handler.go
#	management/server/http/handlers/peers/peers_handler.go
#	management/server/http/middleware/auth_middleware.go
#	management/server/http/middleware/auth_middleware_test.go
#	management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
#	management/server/http/testing/benchmarks/users_handler_benchmark_test.go
#	management/server/integrated_validator.go
#	management/server/mock_server/account_mock.go
#	management/server/peer.go
#	management/server/status/error.go
#	management/server/store/sql_store.go
#	management/server/store/sql_store_test.go
#	management/server/user.go
2025-02-24 13:39:38 +00:00

441 lines
12 KiB

package peers
import (
nbcontext ""
nbpeer ""
// Identifiers shared by the peers handler tests in this file.
const (
	// testPeerID is the ID of the primary peer fixture used by TestGetPeers.
	testPeerID = "test_peer"
	// noUpdateChannelTestPeerID is the ID given to the peer fixture that the
	// mocked account manager treats as having no connected update channel.
	noUpdateChannelTestPeerID = "no-update-channel"

	// adminUser, regularUser and serviceUser are the user IDs registered in
	// the test account built by initTestMetaData.
	adminUser   = "admin_user"
	regularUser = "regular_user"
	serviceUser = "service_user"
)
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
policy := &types.Policy{
ID: "policy",
AccountID: "test_id",
Name: "policy",
Enabled: true,
Rules: []*types.PolicyRule{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
srvUser := types.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &types.Account{
Id: "test_id",
Domain: "",
Peers: peersMap,
Users: map[string]*types.User{
adminUser: types.NewAdminUser(adminUser),
regularUser: types.NewRegularUser(regularUser),
serviceUser: srvUser,
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
AccountID: "test_id",
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
Policies: []*types.Policy{policy},
Network: &types.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP(""),
Mask: net.IPv4Mask(255, 255, 0, 0),
Serial: 51,
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if update.ID == peer.ID {
p = peer.Copy()
p.SSHEnabled = update.SSHEnabled
p.LoginExpirationEnabled = update.LoginExpirationEnabled
p.Name = update.Name
return p, nil
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if peerID == peer.ID {
p = peer.Copy()
return p, nil
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
peersID := make([]string, len(peers))
for _, peer := range peers {
peersID = append(peersID, peer.ID)
return []*types.Group{
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: peersID,
}, nil
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
return account, nil
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
return account, nil
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
return account, nil
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID {
statuses[peer.ID] = struct{}{}
_, ok := statuses[peerID]
return ok
// Tests the GetAllPeers endpoint reachable in the route /api/peers
// Use the metadata generated by initTestMetaData() to check for values
func TestGetPeers(t *testing.T) {
peer := &nbpeer.Peer{
ID: testPeerID,
Key: "key",
IP: net.ParseIP(""),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "PeerName",
LoginExpirationEnabled: false,
Meta: nbpeer.PeerSystemMeta{
Hostname: "hostname",
GoOS: "GoOS",
Kernel: "kernel",
Core: "core",
Platform: "platform",
OS: "OS",
WtVersion: "development",
SystemSerialNumber: "C02XJ0J0JGH7",
peer1 := peer.Copy()
peer1.ID = noUpdateChannelTestPeerID
expectedUpdatedPeer := peer.Copy()
expectedUpdatedPeer.LoginExpirationEnabled = true
expectedUpdatedPeer.SSHEnabled = true
expectedUpdatedPeer.Name = "New Name"
expectedPeer1 := peer1.Copy()
expectedPeer1.Status.Connected = false
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestBody io.Reader
expectedArray bool
expectedPeer *nbpeer.Peer
name: "GetPeersMetaData",
requestType: http.MethodGet,
requestPath: "/api/peers/",
expectedStatus: http.StatusOK,
expectedArray: true,
expectedPeer: peer,
name: "GetPeer with update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: peer,
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
name: "PutPeer",
requestType: http.MethodPut,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
requestBody: bytes.NewBufferString("{\"login_expiration_enabled\":true,\"name\":\"New Name\",\"ssh_enabled\":true}"),
expectedPeer: expectedUpdatedPeer,
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
for _, tc := range tt {
t.Run(, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "admin_user",
Domain: "",
AccountId: "test_id",
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
var got *api.Peer
if tc.expectedArray {
respBody := []*api.Peer{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber)
func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP(""),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP(""),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP(""),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
p := initTestMetaData(peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
name: "service user can access unowned peer",
peerID: "peer3",
callerUserID: serviceUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
for _, tc := range tt {
t.Run(, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: tc.callerUserID,
Domain: "",
AccountId: "test_id",
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
defer res.Body.Close()
var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)