// netbird/management/server/http/nameservers_test.go
// (288 lines, 9.4 KiB, Go)

package http
import (
"bytes"
"encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
// Identifiers shared by the nameserver handler fixtures and test cases.
const (
	// testNSGroupAccountID is the account ID carried by the stubbed JWT claims.
	testNSGroupAccountID = "test_id"

	// existingNSGroupID is the only nameserver group ID the mock account manager recognises.
	existingNSGroupID = "existingNSGroupID"

	// notFoundNSGroupID is an ID the mock account manager reports as missing.
	notFoundNSGroupID = "notFoundNSGroupID"
)
var (
	// testingNSAccount is the account returned by the mock GetAccountFromToken.
	testingNSAccount = &server.Account{
		Domain: "hotmail.com",
		Id:     testNSGroupAccountID,
	}

	// baseExistingNSGroup is the fixture behind existingNSGroupID: two UDP
	// nameservers on the default DNS port, attached to the "testing" group.
	// The mock callbacks in this file return copies of it rather than the
	// pointer itself.
	baseExistingNSGroup = &nbdns.NameServerGroup{
		ID:          existingNSGroupID,
		Name:        "super",
		Description: "super",
		Groups:      []string{"testing"},
		Enabled:     true,
		NameServers: []nbdns.NameServer{
			{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: nbdns.DefaultDNSPort},
			{IP: netip.MustParseAddr("1.1.2.2"), NSType: nbdns.UDPNameServerType, Port: nbdns.DefaultDNSPort},
		},
	}
)
// initNameserversTestData returns a Nameservers handler wired to a mock
// account manager and a stub JWT claims extractor.
//
// The mock only recognises existingNSGroupID. Get, Save and Update calls for
// any other ID fail with a gRPC NotFound status. Create always succeeds and
// echoes its input under existingNSGroupID, and Delete always succeeds. Every
// request resolves to testingNSAccount.
func initNameserversTestData() *Nameservers {
	return &Nameservers{
		accountManager: &mock_server.MockAccountManager{
			GetNameServerGroupFunc: func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
				if nsGroupID == existingNSGroupID {
					// Hand out a copy so handlers cannot mutate the shared fixture.
					return baseExistingNSGroup.Copy(), nil
				}
				return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
			},
			CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
				return &nbdns.NameServerGroup{
					ID:          existingNSGroupID,
					Name:        name,
					Description: description,
					NameServers: nameServerList,
					Groups:      groups,
					Enabled:     enabled,
				}, nil
			},
			DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
				return nil
			},
			SaveNameServerGroupFunc: func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
				if nsGroupToSave.ID == existingNSGroupID {
					return nil
				}
				return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
			},
			UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
				nsGroupToUpdate := baseExistingNSGroup.Copy()
				if nsGroupID != nsGroupToUpdate.ID {
					return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
				}
				for _, operation := range operations {
					switch operation.Type {
					case server.UpdateNameServerGroupName:
						// Guard the index: an empty value list must be an error, not a panic.
						if len(operation.Values) == 0 {
							return nil, status.Errorf(codes.InvalidArgument, "no value provided for nameserver group name")
						}
						nsGroupToUpdate.Name = operation.Values[0]
					case server.UpdateNameServerGroupDescription:
						if len(operation.Values) == 0 {
							return nil, status.Errorf(codes.InvalidArgument, "no value provided for nameserver group description")
						}
						nsGroupToUpdate.Description = operation.Values[0]
					case server.UpdateNameServerGroupNameServers:
						var parsedNSList []nbdns.NameServer
						for _, nsURL := range operation.Values {
							parsed, err := nbdns.ParseNameServerURL(nsURL)
							if err != nil {
								return nil, err
							}
							parsedNSList = append(parsedNSList, parsed)
						}
						nsGroupToUpdate.NameServers = parsedNSList
					}
				}
				return nsGroupToUpdate, nil
			},
			GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
				return testingNSAccount, nil
			},
		},
		authAudience: "",
		jwtExtractor: jwtclaims.ClaimsExtractor{
			// The audience is ignored: every request gets the same fixed claims.
			ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
				return jwtclaims.AuthorizationClaims{
					UserId:    "test_user",
					Domain:    "hotmail.com",
					AccountId: testNSGroupAccountID,
				}
			},
		},
	}
}
func TestNameserversHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedNSGroup *api.NameserverGroup
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: toNameserverGroupResponse(baseExistingNSGroup),
},
{
name: "Get Not Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "POST Invalid Nameserver",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PUT OK",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "PUT Not Existing Nameserver Group",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "PUT Invalid Nameserver",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PATCH OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: baseExistingNSGroup.Name,
Description: "NewDesc",
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled,
},
},
{
name: "PATCH Invalid Nameserver Group OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + notFoundRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
p := initNameserversTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{id}", p.GetNameserverGroupHandler).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroupHandler).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{id}", p.DeleteNameserverGroupHandler).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{id}", p.UpdateNameserverGroupHandler).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{id}", p.PatchNameserverGroupHandler).Methods("PATCH")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.NameserverGroup{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedNSGroup, got)
})
}
}