netbird/management/server/http/handlers/dns/dns_settings_handler_test.go
Viktor Liu ddc365f7a0
[client, management] Add new network concept (#3047)
---------

Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-12-20 11:30:28 +01:00

153 lines
4.3 KiB
Go

package dns
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
testDNSSettingsAccountID = "test_id"
testDNSSettingsExistingGroup = "test_group"
testDNSSettingsUserID = "test_user"
)
var baseExistingDNSSettings = types.DNSSettings{
DisabledManagementGroups: []string{testDNSSettingsExistingGroup},
}
var testingDNSSettingsAccount = &types.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*types.User{
testDNSSettingsUserID: types.NewAdminUser("test_user"),
},
DNSSettings: baseExistingDNSSettings,
}
func initDNSSettingsTestData() *dnsSettingsHandler {
return &dnsSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
if dnsSettingsToSave != nil {
return nil
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testDNSSettingsAccountID,
}
}),
),
}
}
func TestDNSSettingsHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedDNSSettings *api.DNSSettings
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get DNS Settings",
requestType: http.MethodGet,
requestPath: "/api/dns/settings",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedDNSSettings: &api.DNSSettings{
DisabledManagementGroups: baseExistingDNSSettings.DisabledManagementGroups,
},
},
{
name: "Update DNS Settings",
requestType: http.MethodPut,
requestPath: "/api/dns/settings",
requestBody: bytes.NewBuffer(
[]byte("{\"disabled_management_groups\":[\"group1\",\"group2\"]}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedDNSSettings: &api.DNSSettings{
DisabledManagementGroups: []string{"group1", "group2"},
},
},
{
name: "Update DNS Settings Empty Body",
requestType: http.MethodPut,
requestPath: "/api/dns/settings",
requestBody: bytes.NewBuffer(
[]byte("{}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedDNSSettings: &api.DNSSettings{},
},
}
p := initDNSSettingsTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.DNSSettings{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedDNSSettings, got)
})
}
}