netbird/management/server/http/rules_test.go

309 lines
8.4 KiB
Go
Raw Normal View History

package http
import (
"bytes"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initRulesTestData(rules ...*server.Rule) *Rules {
return &Rules{
accountManager: &mock_server.MockAccountManager{
SaveRuleFunc: func(_ string, rule *server.Rule) error {
if !strings.HasPrefix(rule.ID, "id-") {
rule.ID = "id-was-set"
}
return nil
},
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
return &server.Rule{
ID: "idoftherule",
Name: "Rule",
Source: []string{"idofsrcrule"},
Destination: []string{"idofdestrule"},
Flow: server.TrafficFlowBidirect,
}, nil
},
UpdateRuleFunc: func(_ string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) {
var rule server.Rule
rule.ID = ruleID
for _, operation := range operations {
switch operation.Type {
case server.UpdateRuleName:
rule.Name = operation.Values[0]
case server.UpdateRuleDescription:
rule.Description = operation.Values[0]
case server.UpdateRuleFlow:
if server.TrafficFlowBidirectString == operation.Values[0] {
rule.Flow = server.TrafficFlowBidirect
} else {
rule.Flow = 100
}
case server.UpdateSourceGroups, server.InsertGroupsToSource:
rule.Source = operation.Values
case server.UpdateDestinationGroups, server.InsertGroupsToDestination:
rule.Destination = operation.Values
case server.RemoveGroupsFromSource, server.RemoveGroupsFromDestination:
default:
return nil, fmt.Errorf("no operation")
}
}
return &rule, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
Groups: map[string]*server.Group{
"F": {ID: "F"},
"G": {ID: "G"},
},
}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}
}
func TestRulesGetRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "GetRule OK",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/rules/idoftherule",
expectedStatus: http.StatusOK,
},
{
name: "GetRule not found",
requestType: http.MethodGet,
requestPath: "/api/rules/notexists",
expectedStatus: http.StatusNotFound,
},
}
rule := &server.Rule{
ID: "idoftherule",
Name: "Rule",
}
p := initRulesTestData(rule)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules/{id}", p.GetRuleHandler).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
var got api.Rule
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.Id, rule.ID)
assert.Equal(t, got.Name, rule.Name)
})
}
}
func TestRulesWriteRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedRule *api.Rule
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "WriteRule POST OK",
requestType: http.MethodPost,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-was-set",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "WriteRule POST Invalid Name",
requestType: http.MethodPost,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"","Flow":"bidirect"}`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "WriteRule PUT OK",
requestType: http.MethodPut,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "WriteRule PUT Invalid Name",
requestType: http.MethodPut,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"","Flow":"bidirect"}`)),
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Write Rule PATCH Name OK",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":["Default POSTed Rule"]}]`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "Write Rule PATCH Invalid Name OP",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Write Rule PATCH Invalid Name",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":[]}]`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Rule PATCH Sources OK",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"sources","value":["G","F"]}]`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Flow: server.TrafficFlowBidirectString,
Sources: []api.GroupMinimum{
{Id: "G"},
{Id: "F"}},
},
},
}
p := initRulesTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules", p.CreateRuleHandler).Methods("POST")
router.HandleFunc("/api/rules/{id}", p.UpdateRuleHandler).Methods("PUT")
router.HandleFunc("/api/rules/{id}", p.PatchRuleHandler).Methods("PATCH")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.Rule{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got, tc.expectedRule)
})
}
}