netbird/client/internal/ingressgw/manager_test.go
Zoltan Papp 26693e4ea8
Feature/port forwarding client ingress (#3242)
Client-side forward handling

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-01-29 16:04:33 +01:00

282 lines
6.5 KiB
Go

package ingressgw
import (
"fmt"
"net/netip"
"testing"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// Compile-time assertions that the mocks below implement firewall.Rule and
// DNATFirewall respectively.
var (
	_ firewall.Rule = (*MocFwRule)(nil)
	_ DNATFirewall  = (*MockDNATFirewall)(nil)
)
// MocFwRule is a minimal firewall.Rule test double that only carries an
// identifier.
type MocFwRule struct {
	id string
}

// ID returns the identifier stored in the rule.
func (m *MocFwRule) ID() string {
	// id is already a string, so no conversion is needed.
	return m.id
}
// MockDNATFirewall is a test double for DNATFirewall whose methods can be
// switched into a failure mode.
type MockDNATFirewall struct {
// throwError, when true, makes AddDNATRule and DeleteDNATRule return an error.
throwError bool
}
// AddDNATRule pretends to install fwdRule. In failure mode it returns an
// error; otherwise it hands back a MocFwRule carrying the forward rule's ID.
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
	if m.throwError {
		return nil, fmt.Errorf("moc error")
	}
	return &MocFwRule{id: fwdRule.ID()}, nil
}
// DeleteDNATRule pretends to remove rule. The rule itself is ignored; the
// call succeeds unless the mock has been put into failure mode.
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
	if !m.throwError {
		return nil
	}
	return fmt.Errorf("moc error")
}
// forceToThrowErrors switches the mock into failure mode: from now on
// AddDNATRule and DeleteDNATRule return an error.
func (m *MockDNATFirewall) forceToThrowErrors() {
m.throwError = true
}
// TestManager_AddRule checks that every forward rule passed to Update on a
// fresh manager is reported back by Rules.
func TestManager_AddRule(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	updates := []firewall.ForwardRule{
		{
			Protocol:          firewall.ProtocolTCP,
			DestinationPort:   *port,
			TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
			TranslatedPort:    *port,
		},
		{
			Protocol:          firewall.ProtocolUDP,
			DestinationPort:   *port,
			TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
			TranslatedPort:    *port,
		},
	}

	if err := mgr.Update(updates); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	rules := mgr.Rules()
	if len(rules) != len(updates) {
		t.Errorf("unexpected rules count: %d", len(rules))
	}
}
// TestManager_UpdateRule applies a TCP rule, then calls Update with only a UDP
// rule, and expects Rules to report exactly that UDP rule afterwards.
func TestManager_UpdateRule(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}
	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	ruleUDP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolUDP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
		TranslatedPort:    *port,
	}
	if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	rules := mgr.Rules()
	if len(rules) != 1 {
		// Fatal rather than Error: rules[0] is indexed below and would panic
		// on an empty slice.
		t.Fatalf("unexpected rules count: %d", len(rules))
	}

	got := rules[0]
	if got.TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
		t.Errorf("unexpected rule: %v", got)
	}
	if got.TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
		t.Errorf("unexpected rule: %v", got)
	}
	if got.DestinationPort.String() != ruleUDP.DestinationPort.String() {
		t.Errorf("unexpected rule: %v", got)
	}
	if got.Protocol != ruleUDP.Protocol {
		t.Errorf("unexpected rule: %v", got)
	}
}
// TestManager_ExtendRules applies a single TCP rule, then calls Update with
// that rule plus a UDP rule, and expects Rules to report both.
func TestManager_ExtendRules(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}
	ruleUDP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolUDP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
		TranslatedPort:    *port,
	}

	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	rules := mgr.Rules()
	if len(rules) != 2 {
		t.Errorf("unexpected rules count: %d", len(rules))
	}
}
// TestManager_UnderlingError installs one rule successfully, switches the
// firewall mock into failure mode, and then expects Update with an additional
// rule to return an error while Rules still reports a single rule.
func TestManager_UnderlingError(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}
	ruleUDP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolUDP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
		TranslatedPort:    *port,
	}

	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	// From here on every firewall call made through the mock fails.
	fw.forceToThrowErrors()

	if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
		t.Errorf("expected error")
	}

	rules := mgr.Rules()
	if len(rules) != 1 {
		t.Errorf("unexpected rules count: %d", len(rules))
	}
}
// TestManager_Cleanup installs one rule and then calls Update with an empty
// rule set, expecting Rules to report nothing afterwards.
func TestManager_Cleanup(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}
	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	// A nil slice is passed instead of an empty literal; this assumes Update
	// only inspects length and contents — revert to the literal if it
	// distinguishes nil.
	var noRules []firewall.ForwardRule
	if err := mgr.Update(noRules); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	rules := mgr.Rules()
	if len(rules) != 0 {
		t.Errorf("unexpected rules count: %d", len(rules))
	}
}
// TestManager_DeleteBrokenRule runs against a firewall mock that fails every
// call. It expects the initial Update to return an error and leave no rule
// tracked, and expects the following empty Update and Close to succeed.
func TestManager_DeleteBrokenRule(t *testing.T) {
	fw := &MockDNATFirewall{}

	// Make the mock fail so that adding the DNAT rule below returns an error.
	fw.forceToThrowErrors()
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}

	// This branch fires when Update unexpectedly succeeds, so report the
	// missing error instead of printing a nil one.
	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
		t.Errorf("expected error")
	}

	rules := mgr.Rules()
	if len(rules) != 0 {
		t.Errorf("unexpected rules count: %d", len(rules))
	}

	// Simulate removing the rule whose installation failed.
	if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	if err := mgr.Close(); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
}
// TestManager_Close installs one rule, closes the manager, and expects Rules
// to report nothing afterwards.
func TestManager_Close(t *testing.T) {
	fw := &MockDNATFirewall{}
	mgr := NewManager(fw)

	port, err := firewall.NewPort(8080)
	if err != nil {
		// port is dereferenced below, so a failed NewPort must abort the test.
		t.Fatalf("unexpected error: %v", err)
	}

	ruleTCP := firewall.ForwardRule{
		Protocol:          firewall.ProtocolTCP,
		DestinationPort:   *port,
		TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
		TranslatedPort:    *port,
	}
	if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	if err := mgr.Close(); err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	rules := mgr.Rules()
	if len(rules) != 0 {
		t.Errorf("unexpected rules count: %d", len(rules))
	}
}