mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
Merge branch 'main' into fix/login-filter
This commit is contained in:
commit
7890cb4f32
@ -149,6 +149,7 @@ nfpms:
|
|||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@ -164,6 +165,7 @@ dockers:
|
|||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
@ -4,12 +4,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Preferences export a subset of the internal config for gomobile
|
// Preferences exports a subset of the internal config for gomobile
|
||||||
type Preferences struct {
|
type Preferences struct {
|
||||||
configInput internal.ConfigInput
|
configInput internal.ConfigInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences create new Preferences instance
|
// NewPreferences creates a new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
|||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetManagementURL read url from config file
|
// GetManagementURL reads URL from config file
|
||||||
func (p *Preferences) GetManagementURL() (string, error) {
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
if p.configInput.ManagementURL != "" {
|
if p.configInput.ManagementURL != "" {
|
||||||
return p.configInput.ManagementURL, nil
|
return p.configInput.ManagementURL, nil
|
||||||
@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
|||||||
return cfg.ManagementURL.String(), err
|
return cfg.ManagementURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetManagementURL store the given url and wait for commit
|
// SetManagementURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetManagementURL(url string) {
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
p.configInput.ManagementURL = url
|
p.configInput.ManagementURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdminURL read url from config file
|
// GetAdminURL reads URL from config file
|
||||||
func (p *Preferences) GetAdminURL() (string, error) {
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
if p.configInput.AdminURL != "" {
|
if p.configInput.AdminURL != "" {
|
||||||
return p.configInput.AdminURL, nil
|
return p.configInput.AdminURL, nil
|
||||||
@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
|||||||
return cfg.AdminURL.String(), err
|
return cfg.AdminURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAdminURL store the given url and wait for commit
|
// SetAdminURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetAdminURL(url string) {
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
p.configInput.AdminURL = url
|
p.configInput.AdminURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreSharedKey read preshared key from config file
|
// GetPreSharedKey reads pre-shared key from config file
|
||||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
if p.configInput.PreSharedKey != nil {
|
if p.configInput.PreSharedKey != nil {
|
||||||
return *p.configInput.PreSharedKey, nil
|
return *p.configInput.PreSharedKey, nil
|
||||||
@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
|||||||
return cfg.PreSharedKey, err
|
return cfg.PreSharedKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPreSharedKey store the given key and wait for commit
|
// SetPreSharedKey stores the given key and waits for commit
|
||||||
func (p *Preferences) SetPreSharedKey(key string) {
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
p.configInput.PreSharedKey = &key
|
p.configInput.PreSharedKey = &key
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRosenpassEnabled store if rosenpass is enabled
|
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||||
p.configInput.RosenpassEnabled = &enabled
|
p.configInput.RosenpassEnabled = &enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRosenpassEnabled read rosenpass enabled from config file
|
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||||
if p.configInput.RosenpassEnabled != nil {
|
if p.configInput.RosenpassEnabled != nil {
|
||||||
return *p.configInput.RosenpassEnabled, nil
|
return *p.configInput.RosenpassEnabled, nil
|
||||||
@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
|||||||
return cfg.RosenpassEnabled, err
|
return cfg.RosenpassEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRosenpassPermissive store the given permissive and wait for commit
|
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||||
p.configInput.RosenpassPermissive = &permissive
|
p.configInput.RosenpassPermissive = &permissive
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRosenpassPermissive read rosenpass permissive from config file
|
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||||
if p.configInput.RosenpassPermissive != nil {
|
if p.configInput.RosenpassPermissive != nil {
|
||||||
return *p.configInput.RosenpassPermissive, nil
|
return *p.configInput.RosenpassPermissive, nil
|
||||||
@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
|||||||
return cfg.RosenpassPermissive, err
|
return cfg.RosenpassPermissive, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit write out the changes into config file
|
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableClientRoutes != nil {
|
||||||
|
return *p.configInput.DisableClientRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableClientRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableClientRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||||
|
p.configInput.DisableClientRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableServerRoutes != nil {
|
||||||
|
return *p.configInput.DisableServerRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableServerRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableServerRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||||
|
p.configInput.DisableServerRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableDNS reads disable DNS setting from config file
|
||||||
|
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||||
|
if p.configInput.DisableDNS != nil {
|
||||||
|
return *p.configInput.DisableDNS, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableDNS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableDNS stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||||
|
p.configInput.DisableDNS = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableFirewall reads disable firewall setting from config file
|
||||||
|
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||||
|
if p.configInput.DisableFirewall != nil {
|
||||||
|
return *p.configInput.DisableFirewall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableFirewall, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableFirewall stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||||
|
p.configInput.DisableFirewall = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||||
|
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||||
|
if p.configInput.ServerSSHAllowed != nil {
|
||||||
|
return *p.configInput.ServerSSHAllowed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.ServerSSHAllowed == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.ServerSSHAllowed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetServerSSHAllowed stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||||
|
p.configInput.ServerSSHAllowed = &allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockInbound reads block inbound setting from config file
|
||||||
|
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||||
|
if p.configInput.BlockInbound != nil {
|
||||||
|
return *p.configInput.BlockInbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.BlockInbound, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBlockInbound stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetBlockInbound(block bool) {
|
||||||
|
p.configInput.BlockInbound = &block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit writes out the changes to the config file
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
return err
|
return err
|
||||||
|
@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
for _, stage := range resp.Stages {
|
for _, stage := range resp.Stages {
|
||||||
if stage.ForwardingDetails != nil {
|
if stage.ForwardingDetails != nil {
|
||||||
|
@ -62,5 +62,5 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -19,6 +20,10 @@ const (
|
|||||||
DefaultICMPTimeout = 30 * time.Second
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
|
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||||
|
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||||
|
MaxICMPPayloadLength = 28
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@ -29,7 +34,7 @@ type ICMPConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i ICMPConnKey) String() string {
|
func (i ICMPConnKey) String() string {
|
||||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
@ -50,6 +55,72 @@ type ICMPTracker struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||||
|
type ICMPInfo struct {
|
||||||
|
TypeCode layers.ICMPv4TypeCode
|
||||||
|
PayloadData [MaxICMPPayloadLength]byte
|
||||||
|
// actual length of valid data
|
||||||
|
PayloadLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||||
|
func (info ICMPInfo) String() string {
|
||||||
|
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||||
|
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||||
|
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.TypeCode.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||||
|
func (info ICMPInfo) isErrorMessage() bool {
|
||||||
|
typ := info.TypeCode.Type()
|
||||||
|
return typ == 3 || // Destination Unreachable
|
||||||
|
typ == 5 || // Redirect
|
||||||
|
typ == 11 || // Time Exceeded
|
||||||
|
typ == 12 // Parameter Problem
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||||
|
func (info ICMPInfo) parseOriginalPacket() string {
|
||||||
|
if info.PayloadLen < MaxICMPPayloadLength {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle IPv6
|
||||||
|
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := info.PayloadData[9]
|
||||||
|
srcIP := net.IP(info.PayloadData[12:16])
|
||||||
|
dstIP := net.IP(info.PayloadData[16:20])
|
||||||
|
|
||||||
|
transportData := info.PayloadData[20:]
|
||||||
|
|
||||||
|
switch nftypes.Protocol(protocol) {
|
||||||
|
case nftypes.TCP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.UDP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.ICMP:
|
||||||
|
icmpType := transportData[0]
|
||||||
|
icmpCode := transportData[1]
|
||||||
|
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP connection
|
// TrackOutbound records an outbound ICMP connection
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
func (t *ICMPTracker) TrackOutbound(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound ICMP Echo Request
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
func (t *ICMPTracker) TrackInbound(
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
func (t *ICMPTracker) track(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
direction nftypes.Direction,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
typ, code := typecode.Type(), typecode.Code()
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
icmpInfo := ICMPInfo{
|
||||||
|
TypeCode: typecode,
|
||||||
|
}
|
||||||
|
if len(payload) > 0 {
|
||||||
|
icmpInfo.PayloadLen = len(payload)
|
||||||
|
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||||
|
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||||
|
}
|
||||||
|
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||||
|
}
|
||||||
|
|
||||||
// non echo requests don't need tracking
|
// non echo requests don't need tracking
|
||||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
|||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
}
|
}
|
||||||
|
@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
if errInToOut != nil {
|
if errInToOut != nil {
|
||||||
if !isClosedError(errInToOut) {
|
if !isClosedError(errInToOut) {
|
||||||
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
|
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errOutToIn != nil {
|
if errOutToIn != nil {
|
||||||
if !isClosedError(errOutToIn) {
|
if !isClosedError(errOutToIn) {
|
||||||
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
|
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
|
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||||
}
|
}
|
||||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
|
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rxPackets, txPackets uint64
|
var rxPackets, txPackets uint64
|
||||||
|
@ -671,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -684,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ type WGTunDevice struct {
|
|||||||
mtu int
|
mtu int
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
|
disableDNS bool
|
||||||
|
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
@ -32,7 +33,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
|
disableDNS: disableDNS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||||
|
|
||||||
|
// Skip DNS configuration when DisableDNS is enabled
|
||||||
|
if t.disableDNS {
|
||||||
|
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||||
|
dns = ""
|
||||||
|
searchDomainsToString = ""
|
||||||
|
}
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
|
@ -43,6 +43,7 @@ type WGIFaceOpts struct {
|
|||||||
MobileArgs *device.MobileIFaceArguments
|
MobileArgs *device.MobileIFaceArguments
|
||||||
TransportNet transport.Net
|
TransportNet transport.Net
|
||||||
FilterFn bind.FilterFn
|
FilterFn bind.FilterFn
|
||||||
|
DisableDNS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// WGIface represents an interface instance
|
// WGIface represents an interface instance
|
||||||
|
@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||||
if drop {
|
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||||
|
|
||||||
|
if hasPortRestrictions {
|
||||||
|
// Don't squash rules with port restrictions
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = &protoMatch{
|
protocols[r.Protocol] = &protoMatch{
|
||||||
ips: map[string]int{},
|
ips: map[string]int{},
|
||||||
|
@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rules []*mgmProto.FirewallRule
|
||||||
|
expectedCount int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should not squash rules with port ranges",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with specific ports",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with legacy port field",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with legacy port field should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with DROP action",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with DROP action should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash rules without port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 1,
|
||||||
|
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed rules should not squash protocol with port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "TCP should not be squashed because one rule has port restrictions",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
// TCP rules with port restrictions - should NOT be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
// UDP rules without port restrictions - SHOULD be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||||
|
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||||
|
{AllowedIps: []string{"10.93.0.1"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.2"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.3"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.4"}},
|
||||||
|
},
|
||||||
|
FirewallRules: tt.rules,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &DefaultManager{}
|
||||||
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||||
|
|
||||||
|
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||||
|
if tt.expectedCount == 1 {
|
||||||
|
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||||
|
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||||
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
portInfo *mgmProto.PortInfo
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil PortInfo should be empty",
|
||||||
|
portInfo: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero port should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid port should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with nil range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero start range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 0,
|
||||||
|
End: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero end range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 80,
|
||||||
|
End: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid range should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := portInfoEmpty(tt.portInfo)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
PeerConfig: &mgmProto.PeerConfig{
|
||||||
|
@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config := &Config{
|
config := &Config{
|
||||||
// defaults to false only for new (post 0.26) configurations
|
// defaults to false only for new (post 0.26) configurations
|
||||||
ServerSSHAllowed: util.False(),
|
ServerSSHAllowed: util.False(),
|
||||||
|
// default to disabling server routes on Android for security
|
||||||
|
DisableServerRoutes: runtime.GOOS == "android",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := config.apply(input); err != nil {
|
if _, err := config.apply(input); err != nil {
|
||||||
@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
updated = true
|
updated = true
|
||||||
} else if config.ServerSSHAllowed == nil {
|
} else if config.ServerSSHAllowed == nil {
|
||||||
|
if runtime.GOOS == "android" {
|
||||||
|
// default to disabled SSH on Android for security
|
||||||
|
log.Infof("setting SSH server to false by default on Android")
|
||||||
|
config.ServerSSHAllowed = util.False()
|
||||||
|
} else {
|
||||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||||
config.ServerSSHAllowed = util.True()
|
config.ServerSSHAllowed = util.True()
|
||||||
|
}
|
||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,8 +11,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityLocal = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityDNSRoute = 75
|
||||||
|
PriorityUpstream = 50
|
||||||
PriorityDefault = 1
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
queryDomain: "test.example.com.",
|
queryDomain: "test.example.com.",
|
||||||
@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
|
||||||
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
queryDomain: "sub.test.example.com.",
|
queryDomain: "sub.test.example.com.",
|
||||||
@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: false,
|
nbdns.PriorityDNSRoute: false,
|
||||||
nbdns.PriorityMatchDomain: true,
|
nbdns.PriorityUpstream: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: true,
|
nbdns.PriorityDNSRoute: true,
|
||||||
nbdns.PriorityMatchDomain: false,
|
nbdns.PriorityUpstream: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -378,15 +378,15 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"add", "example.com.", nbdns.PriorityDefault},
|
{"add", "example.com.", nbdns.PriorityDefault},
|
||||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: false,
|
nbdns.PriorityDNSRoute: false,
|
||||||
nbdns.PriorityMatchDomain: false,
|
nbdns.PriorityUpstream: false,
|
||||||
nbdns.PriorityDefault: true,
|
nbdns.PriorityDefault: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
// Test 1: Initial state
|
// Test 1: Initial state
|
||||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
defaultHandler.Calls = nil
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
}{
|
}{
|
||||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
{"example.com.", nbdns.PriorityUpstream, false, false},
|
||||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "test.sub.example.com.",
|
query: "test.sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
|
@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||||
|
|
||||||
for _, domainGroup := range groupedNS {
|
for _, domainGroup := range groupedNS {
|
||||||
basePriority := PriorityMatchDomain
|
basePriority := PriorityUpstream
|
||||||
if domainGroup.domain == nbdns.RootZone {
|
if domainGroup.domain == nbdns.RootZone {
|
||||||
basePriority = PriorityDefault
|
basePriority = PriorityDefault
|
||||||
}
|
}
|
||||||
@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||||
priority := basePriority - i
|
priority := basePriority - i
|
||||||
|
|
||||||
// Check if we're about to overlap with the next priority tier
|
// Check if we're about to overlap with the next priority tier.
|
||||||
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
// This boundary check ensures that the priority of upstream handlers does not conflict
|
||||||
|
// with the default priority tier. By decrementing the priority for each handler, we avoid
|
||||||
|
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
|
||||||
|
// handlers to maintain the integrity of the priority system.
|
||||||
|
if basePriority == PriorityUpstream && priority <= PriorityDefault {
|
||||||
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||||
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
domainGroup.domain, PriorityUpstream-PriorityDefault)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
dummyHandler.ID(): handlerWrapper{
|
dummyHandler.ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"local-resolver": handlerWrapper{
|
"local-resolver": handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
"id1": handlerWrapper{
|
"id1": handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: &local.Resolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||||
@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
"upstream-group2": {
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
"upstream-group2": {
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
"upstream-other": {
|
"upstream-other": {
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group3",
|
Id: "upstream-group3",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain + 1,
|
priority: PriorityUpstream + 1,
|
||||||
},
|
},
|
||||||
// Keep existing groups with their original priorities
|
// Keep existing groups with their original priorities
|
||||||
{
|
{
|
||||||
@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
// Add group3 with lowest priority
|
// Add group3 with lowest priority
|
||||||
{
|
{
|
||||||
@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group3",
|
Id: "upstream-group3",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 2,
|
priority: PriorityUpstream - 2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "new.com",
|
domain: "new.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-new",
|
Id: "upstream-new",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
|||||||
|
|
||||||
// Register domains from different handlers with same domain
|
// Register domains from different handlers with same domain
|
||||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
|
||||||
|
|
||||||
// Verify refcount is 2
|
// Verify refcount is 2
|
||||||
zoneKey := toZone("shared.example.com")
|
zoneKey := toZone("shared.example.com")
|
||||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||||
|
|
||||||
// Deregister one handler
|
// Deregister one handler
|
||||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
|
||||||
|
|
||||||
// Verify refcount is 1
|
// Verify refcount is 1
|
||||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||||
@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
|
||||||
|
|
||||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||||
|
|
||||||
@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
service: &mockService{},
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "local.example.com",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "test.local.example.com",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.100",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"local.example.com"}, // Same domain as local records
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that local handler has higher priority than upstream for same domain
|
||||||
|
var localPriority, upstreamPriority int
|
||||||
|
localFound, upstreamFound := false, false
|
||||||
|
|
||||||
|
for _, update := range localMuxUpdates {
|
||||||
|
if update.domain == "local.example.com" {
|
||||||
|
localPriority = update.priority
|
||||||
|
localFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, update := range upstreamMuxUpdates {
|
||||||
|
if update.domain == "local.example.com" {
|
||||||
|
upstreamPriority = update.priority
|
||||||
|
upstreamFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, localFound, "Local handler should be found")
|
||||||
|
assert.True(t, upstreamFound, "Upstream handler should be found")
|
||||||
|
assert.Greater(t, localPriority, upstreamPriority,
|
||||||
|
"Local handler priority (%d) should be higher than upstream priority (%d)",
|
||||||
|
localPriority, upstreamPriority)
|
||||||
|
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
|
||||||
|
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||||
|
// Test that priority constants are ordered correctly
|
||||||
|
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||||
|
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||||
|
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||||
|
|
||||||
|
// Test that local resolver uses the correct priority
|
||||||
|
server := &DefaultServer{
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "local.example.com",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "test.local.example.com",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.100",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, localMuxUpdates, 1)
|
||||||
|
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||||
|
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
requestID := GenerateRequestID()
|
||||||
|
logger := log.WithField("request_id", requestID)
|
||||||
var err error
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
log.Tracef("%s has been stopped", u)
|
logger.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
if rm == nil || !rm.Response {
|
||||||
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
if err = w.WriteMsg(rm); err != nil {
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(r, dns.RcodeServerFailure)
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(m); err != nil {
|
if err := w.WriteMsg(m); err != nil {
|
||||||
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
|
|
||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateRequestID() string {
|
||||||
|
bytes := make([]byte, 4)
|
||||||
|
_, err := rand.Read(bytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to generate request ID: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
|
@ -84,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
|
return &dns.Client{
|
||||||
|
Timeout: dialTimeout,
|
||||||
|
Net: "udp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
@ -36,3 +36,10 @@ func newUpstreamResolver(
|
|||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
|
return &dns.Client{
|
||||||
|
Timeout: dialTimeout,
|
||||||
|
Net: "udp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
@ -18,14 +18,20 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
const upstreamTimeout = 15 * time.Second
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type firewaller interface {
|
||||||
|
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
|
||||||
|
}
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
ttl uint32
|
ttl uint32
|
||||||
@ -38,16 +44,18 @@ type DNSForwarder struct {
|
|||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
firewall firewall.Manager
|
firewall firewaller
|
||||||
|
resolver resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
resolver: net.DefaultResolver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|||||||
// UDP server
|
// UDP server
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
f.mux = mux
|
f.mux = mux
|
||||||
|
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||||
f.dnsServer = &dns.Server{
|
f.dnsServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCP server
|
// TCP server
|
||||||
tcpMux := dns.NewServeMux()
|
tcpMux := dns.NewServeMux()
|
||||||
f.tcpMux = tcpMux
|
f.tcpMux = tcpMux
|
||||||
|
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||||
f.tcpServer = &dns.Server{
|
f.tcpServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|||||||
// return the first error we get (e.g. bind failure or shutdown)
|
// return the first error we get (e.g. bind failure or shutdown)
|
||||||
return <-errCh
|
return <-errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
if f.mux == nil {
|
|
||||||
log.Debug("DNS mux is nil, skipping domain update")
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
return
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||||
}
|
|
||||||
|
|
||||||
oldDomains := filterDomains(f.fwdEntries)
|
|
||||||
for _, d := range oldDomains {
|
|
||||||
f.mux.HandleRemove(d.PunycodeString())
|
|
||||||
f.tcpMux.HandleRemove(d.PunycodeString())
|
|
||||||
}
|
|
||||||
|
|
||||||
newDomains := filterDomains(entries)
|
|
||||||
for _, d := range newDomains {
|
|
||||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
|
|
||||||
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.fwdEntries = entries
|
|
||||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||||
|
// query doesn't match any configured domain
|
||||||
|
if mostSpecificResId == "" {
|
||||||
|
resp.Rcode = dns.RcodeRefused
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, query, resp, domain, err)
|
f.handleDNSError(w, query, resp, domain, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(domain, ips)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
|
||||||
resp := f.handleDNSQuery(w, query)
|
resp := f.handleDNSQuery(w, query)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return
|
return
|
||||||
@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
var prefixes []netip.Prefix
|
var prefixes []netip.Prefix
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
|
||||||
if mostSpecificResId != "" {
|
if mostSpecificResId != "" {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
|||||||
|
|
||||||
return selectedResId, matches
|
return selectedResId, matches
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDomains returns a list of normalized domains
|
|
||||||
func filterDomains(entries []*ForwarderEntry) domain.List {
|
|
||||||
newDomains := make(domain.List, 0, len(entries))
|
|
||||||
for _, d := range entries {
|
|
||||||
if d.Domain == "" {
|
|
||||||
log.Warn("empty domain in DNS forwarder")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
|
|
||||||
}
|
|
||||||
return newDomains
|
|
||||||
}
|
|
||||||
|
@ -1,11 +1,21 @@
|
|||||||
package dnsfwd
|
package dnsfwd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -13,7 +23,7 @@ import (
|
|||||||
func Test_getMatchingEntries(t *testing.T) {
|
func Test_getMatchingEntries(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
storedMappings map[string]route.ResID // key: domain pattern, value: resId
|
storedMappings map[string]route.ResID
|
||||||
queryDomain string
|
queryDomain string
|
||||||
expectedResId route.ResID
|
expectedResId route.ResID
|
||||||
}{
|
}{
|
||||||
@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Wildcard pattern does not match different domain",
|
name: "Wildcard pattern does not match different domain",
|
||||||
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
||||||
queryDomain: "foo.notexample.com",
|
queryDomain: "foo.example.org",
|
||||||
expectedResId: "",
|
expectedResId: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MockFirewall struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
args := m.Called(set, prefixes)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockResolver struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
args := m.Called(ctx, network, host)
|
||||||
|
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldMatch bool
|
||||||
|
expectedResID route.ResID
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain match should be allowed",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Direct match to configured domain should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain access should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Subdomain should not be accessible unless explicitly configured",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard domains should allow subdomain access",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should also match the base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Deep subdomains should not be accessible",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should allow deep subdomains",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := &DNSForwarder{}
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
|
||||||
|
|
||||||
|
if tt.shouldMatch {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
|
||||||
|
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
|
||||||
|
assert.Empty(t, matchingEntries, "Expected no matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldResolve bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "configured exact domain resolves",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Exact match should resolve",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Subdomain should be blocked without wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow subdomain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unrelated domain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.org",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Unrelated domain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Deep subdomain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow deep subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
|
||||||
|
|
||||||
|
// Mock successful DNS resolution
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
Set: firewall.NewDomainSet([]domain.Domain{d}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
|
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
} else {
|
||||||
|
if resp != nil {
|
||||||
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
|
"Unauthorized domain should not return successful answers")
|
||||||
|
}
|
||||||
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomains []string
|
||||||
|
query string
|
||||||
|
mockIP string
|
||||||
|
shouldResolve bool
|
||||||
|
expectedSetCount int // How many sets should be updated
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain gets firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "example.com",
|
||||||
|
mockIP: "1.1.1.1",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Single exact match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard domain gets firewall update",
|
||||||
|
configuredDomains: []string{"*.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.2",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Wildcard match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping exact and wildcard both get updates",
|
||||||
|
configuredDomains: []string{"*.example.com", "mail.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.3",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "Both exact and wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized domain gets no firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.4",
|
||||||
|
shouldResolve: false,
|
||||||
|
expectedSetCount: 0,
|
||||||
|
description: "No firewall update for unauthorized domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple wildcards matching get all updated",
|
||||||
|
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
|
||||||
|
query: "test.sub.example.com",
|
||||||
|
mockIP: "1.1.1.5",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "All matching wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
// Set up forwarder
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Create entries and track sets
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make([]firewall.Set, 0)
|
||||||
|
|
||||||
|
for i, configDomain := range tt.configuredDomains {
|
||||||
|
d, err := domain.FromString(configDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets = append(sets, set)
|
||||||
|
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Set up mocks
|
||||||
|
if tt.shouldResolve {
|
||||||
|
fakeIP := netip.MustParseAddr(tt.mockIP)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
|
||||||
|
Return([]netip.Addr{fakeIP}, nil).Once()
|
||||||
|
|
||||||
|
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
|
||||||
|
|
||||||
|
// Count how many sets should actually match
|
||||||
|
updateCount := 0
|
||||||
|
for i, entry := range entries {
|
||||||
|
domain := strings.ToLower(tt.query)
|
||||||
|
pattern := entry.Domain.PunycodeString()
|
||||||
|
|
||||||
|
matches := false
|
||||||
|
if strings.HasPrefix(pattern, "*.") {
|
||||||
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||||
|
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
} else if domain == pattern {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches {
|
||||||
|
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
|
||||||
|
updateCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedSetCount, updateCount,
|
||||||
|
"Expected %d sets to be updated, but mock expects %d",
|
||||||
|
tt.expectedSetCount, updateCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
dnsQuery := &dns.Msg{}
|
||||||
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.NotEmpty(t, resp.Answer)
|
||||||
|
} else if resp != nil {
|
||||||
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
|
"Unauthorized domain should be refused or have no answers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all mock expectations were met
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
|
||||||
|
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Configure a single domain
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
entries := []*ForwarderEntry{{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res",
|
||||||
|
Set: set,
|
||||||
|
}}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock resolver returns multiple IPs
|
||||||
|
ips := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
netip.MustParseAddr("1.1.1.2"),
|
||||||
|
netip.MustParseAddr("1.1.1.3"),
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return(ips, nil).Once()
|
||||||
|
|
||||||
|
// Expect ONE UpdateSet call with ALL prefixes
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(ips[0], 32),
|
||||||
|
netip.PrefixFrom(ips[1], 32),
|
||||||
|
netip.PrefixFrom(ips[2], 32),
|
||||||
|
}
|
||||||
|
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Verify response contains all IPs
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryType uint16
|
||||||
|
queryDomain string
|
||||||
|
configured string
|
||||||
|
expectedCode int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unauthorized domain returns REFUSED",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
queryDomain: "evil.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeRefused,
|
||||||
|
description: "RFC compliant REFUSED for unauthorized queries",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported query type returns NOTIMP",
|
||||||
|
queryType: dns.TypeMX,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "RFC compliant NOTIMP for unsupported types",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CNAME query returns NOTIMP",
|
||||||
|
queryType: dns.TypeCNAME,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "CNAME queries not supported",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TXT query returns NOTIMP",
|
||||||
|
queryType: dns.TypeTXT,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "TXT queries not supported",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configured)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||||
|
|
||||||
|
// Capture the written response
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Check the response written to the writer
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, _ := domain.FromString("example.com")
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock many IPs to create a large response
|
||||||
|
var manyIPs []netip.Addr
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
|
||||||
|
|
||||||
|
// Query without EDNS0
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
forwarder.handleDNSQueryUDP(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp)
|
||||||
|
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
|
||||||
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||||
|
// Test complex overlapping pattern scenarios
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Set up complex overlapping patterns
|
||||||
|
patterns := []string{
|
||||||
|
"*.example.com", // Matches all subdomains
|
||||||
|
"*.mail.example.com", // More specific wildcard
|
||||||
|
"smtp.mail.example.com", // Exact match
|
||||||
|
"example.com", // Base domain
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make(map[string]firewall.Set)
|
||||||
|
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
d, _ := domain.FromString(pattern)
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets[pattern] = set
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID("res-" + pattern),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Test smtp.mail.example.com - should match 3 patterns
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
|
||||||
|
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
|
||||||
|
// All three matching patterns should get firewall updates
|
||||||
|
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
|
// Verify all three sets were updated
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Verify the most specific ResID was selected
|
||||||
|
// (exact match should win over wildcards)
|
||||||
|
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
|
||||||
|
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
|
||||||
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
|
// Test handling of malformed query with no questions
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
// Don't set any question
|
||||||
|
|
||||||
|
writeCalled := false
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writeCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
assert.Nil(t, resp, "Should return nil for empty query")
|
||||||
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||||
|
}
|
||||||
|
@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
TransportNet: transportNet,
|
TransportNet: transportNet,
|
||||||
FilterFn: e.addrViaRoutes,
|
FilterFn: e.addrViaRoutes,
|
||||||
|
DisableDNS: e.config.DisableDNS,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
|
@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
|||||||
eventStr = "Ended"
|
eventStr = "Ended"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
log.Tracef("%s %s %s connection: %s:%d → %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
c.flowLogger.StoreEvent(nftypes.EventFields{
|
c.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
|
@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
// FinishPeerListModifications this event invoke the notification
|
// FinishPeerListModifications this event invoke the notification
|
||||||
func (d *Status) FinishPeerListModifications() {
|
func (d *Status) FinishPeerListModifications() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
if !d.peerListChangedForNotification {
|
if !d.peerListChangedForNotification {
|
||||||
d.mux.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.peerListChangedForNotification = false
|
d.peerListChangedForNotification = false
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -23,7 +22,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
handlerTypeDynamic = iota
|
handlerTypeDynamic = iota
|
||||||
handlerTypeDomain
|
handlerTypeDnsInterceptor
|
||||||
handlerTypeStatic
|
handlerTypeStatic
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -566,13 +565,14 @@ func HandlerFromRoute(
|
|||||||
useNewDNSRoute bool,
|
useNewDNSRoute bool,
|
||||||
) RouteHandler {
|
) RouteHandler {
|
||||||
switch handlerType(rt, useNewDNSRoute) {
|
switch handlerType(rt, useNewDNSRoute) {
|
||||||
case handlerTypeDomain:
|
case handlerTypeDnsInterceptor:
|
||||||
return dnsinterceptor.New(
|
return dnsinterceptor.New(
|
||||||
rt,
|
rt,
|
||||||
routeRefCounter,
|
routeRefCounter,
|
||||||
allowedIPsRefCounter,
|
allowedIPsRefCounter,
|
||||||
statusRecorder,
|
statusRecorder,
|
||||||
dnsServer,
|
dnsServer,
|
||||||
|
wgInterface,
|
||||||
peerStore,
|
peerStore,
|
||||||
)
|
)
|
||||||
case handlerTypeDynamic:
|
case handlerTypeDynamic:
|
||||||
@ -596,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
|||||||
return handlerTypeStatic
|
return handlerTypeStatic
|
||||||
}
|
}
|
||||||
|
|
||||||
if useNewDNSRoute && runtime.GOOS != "ios" {
|
if useNewDNSRoute {
|
||||||
return handlerTypeDomain
|
return handlerTypeDnsInterceptor
|
||||||
}
|
}
|
||||||
return handlerTypeDynamic
|
return handlerTypeDynamic
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@ -23,6 +24,11 @@ import (
|
|||||||
|
|
||||||
type domainMap map[domain.Domain][]netip.Prefix
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type wgInterface interface {
|
||||||
|
Name() string
|
||||||
|
Address() wgaddr.Address
|
||||||
|
}
|
||||||
|
|
||||||
type DnsInterceptor struct {
|
type DnsInterceptor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
route *route.Route
|
route *route.Route
|
||||||
@ -32,6 +38,7 @@ type DnsInterceptor struct {
|
|||||||
dnsServer nbdns.Server
|
dnsServer nbdns.Server
|
||||||
currentPeerKey string
|
currentPeerKey string
|
||||||
interceptedDomains domainMap
|
interceptedDomains domainMap
|
||||||
|
wgInterface wgInterface
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,6 +48,7 @@ func New(
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
dnsServer nbdns.Server,
|
dnsServer nbdns.Server,
|
||||||
|
wgInterface wgInterface,
|
||||||
peerStore *peerstore.Store,
|
peerStore *peerstore.Store,
|
||||||
) *DnsInterceptor {
|
) *DnsInterceptor {
|
||||||
return &DnsInterceptor{
|
return &DnsInterceptor{
|
||||||
@ -49,6 +57,7 @@ func New(
|
|||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: dnsServer,
|
||||||
|
wgInterface: wgInterface,
|
||||||
interceptedDomains: make(domainMap),
|
interceptedDomains: make(domainMap),
|
||||||
peerStore: peerStore,
|
peerStore: peerStore,
|
||||||
}
|
}
|
||||||
@ -135,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
|
|
||||||
// ServeDNS implements the dns.Handler interface
|
// ServeDNS implements the dns.Handler interface
|
||||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
requestID := nbdns.GenerateRequestID()
|
||||||
|
logger := log.WithField("request_id", requestID)
|
||||||
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
// pass if non A/AAAA query
|
// pass if non A/AAAA query
|
||||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||||
d.continueToNextHandler(w, r, "non A/AAAA query")
|
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,29 +164,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
|
|
||||||
if peerKey == "" {
|
if peerKey == "" {
|
||||||
d.writeDNSError(w, r, "no current peer key")
|
d.writeDNSError(w, r, logger, "no current peer key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
|
||||||
|
if err != nil {
|
||||||
|
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
client := &dns.Client{
|
|
||||||
Timeout: nbdns.UpstreamTimeout,
|
|
||||||
Net: "udp",
|
|
||||||
}
|
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -184,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
answer = reply.Answer
|
answer = reply.Answer
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
if err := d.writeMsg(w, reply); err != nil {
|
if err := d.writeMsg(w, reply); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||||
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
resp := new(dns.Msg)
|
||||||
resp.SetRcode(r, dns.RcodeServerFailure)
|
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS error response: %v", err)
|
logger.Errorf("failed to write DNS error response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// continueToNextHandler signals the handler chain to try the next handler
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
resp := new(dns.Msg)
|
||||||
resp.SetRcode(r, dns.RcodeNameError)
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
// Set Zero bit to signal handler chain to continue
|
// Set Zero bit to signal handler chain to continue
|
||||||
resp.MsgHdr.Zero = true
|
resp.MsgHdr.Zero = true
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed writing DNS continue response: %v", err)
|
logger.Errorf("failed writing DNS continue response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
GetClientRoutesFunc func() route.HAMap
|
GetClientRoutesFunc func() route.HAMap
|
||||||
|
@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||||
nets := make([]string, 0)
|
nets := make([]string, 0)
|
||||||
for _, r := range clientRoutes {
|
for _, r := range clientRoutes {
|
||||||
// filter out domain routes
|
|
||||||
if r.IsDynamic() {
|
if r.IsDynamic() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
|||||||
if runtime.GOOS != "android" {
|
if runtime.GOOS != "android" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
newNets := make([]string, 0)
|
|
||||||
|
var newNets []string
|
||||||
for _, routes := range idMap {
|
for _, routes := range idMap {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
newNets = append(newNets, r.Network.String())
|
newNets = append(newNets, r.Network.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(newNets)
|
sort.Strings(newNets)
|
||||||
switch runtime.GOOS {
|
|
||||||
case "android":
|
|
||||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
if !n.hasDiff(n.routeRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
n.routeRanges = newNets
|
||||||
|
|
||||||
n.notify()
|
n.notify()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnNewPrefixes is called from iOS only
|
||||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||||
newNets := make([]string, 0)
|
newNets := make([]string, 0)
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(newNets)
|
sort.Strings(newNets)
|
||||||
switch runtime.GOOS {
|
|
||||||
case "android":
|
|
||||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if !n.hasDiff(n.routeRanges, newNets) {
|
if !n.hasDiff(n.routeRanges, newNets) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
n.routeRanges = newNets
|
||||||
|
|
||||||
n.notify()
|
n.notify()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -572,7 +572,7 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
var systrayIconState bool
|
var systrayIconState bool
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled():
|
case status.Status == string(internal.StatusConnected):
|
||||||
s.connected = true
|
s.connected = true
|
||||||
s.sendNotification = true
|
s.sendNotification = true
|
||||||
if s.isUpdateIconActive {
|
if s.isUpdateIconActive {
|
||||||
|
@ -12,6 +12,8 @@ import (
|
|||||||
"fyne.io/fyne/v2"
|
"fyne.io/fyne/v2"
|
||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
type eventHandler struct {
|
type eventHandler struct {
|
||||||
@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *eventHandler) handleUpdateClick() {
|
func (h *eventHandler) handleUpdateClick() {
|
||||||
if err := openURL("https://netbird.io/download"); err != nil {
|
if err := openURL(version.DownloadUrl()); err != nil {
|
||||||
log.Errorf("failed to open download URL: %v", err)
|
log.Errorf("failed to open download URL: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
@ -412,14 +413,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C
|
|||||||
event = activity.AccountPeerLoginExpirationDisabled
|
event = activity.AccountPeerLoginExpirationDisabled
|
||||||
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||||
} else {
|
} else {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||||
|
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -457,6 +459,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
|
|||||||
|
|
||||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
return func() (time.Duration, bool) {
|
return func() (time.Duration, bool) {
|
||||||
|
//nolint
|
||||||
|
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -481,8 +487,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
|
func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) {
|
||||||
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
if am.peerLoginExpiry.IsSchedulerRunning(accountID) {
|
||||||
|
log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID)
|
||||||
|
return
|
||||||
|
}
|
||||||
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
|
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
|
||||||
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
|
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
|
||||||
}
|
}
|
||||||
|
@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(2)
|
wg.Add(1)
|
||||||
manager.peerLoginExpiry = &MockScheduler{
|
manager.peerLoginExpiry = &MockScheduler{
|
||||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
|
||||||
wg.Done()
|
|
||||||
},
|
|
||||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
},
|
},
|
||||||
|
@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
|
|
||||||
for _, groupID := range groupIDs {
|
|
||||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||||
|
@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
|||||||
|
|
||||||
// fetch all the peers that have access to the user's peers
|
// fetch all the peers that have access to the user's peers
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
||||||
for _, p := range aclPeers {
|
for _, p := range aclPeers {
|
||||||
peersMap[p.ID] = p
|
peersMap[p.ID] = p
|
||||||
}
|
}
|
||||||
@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||||
@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||||
|
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1148,7 +1149,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range userPeers {
|
for _, p := range userPeers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
||||||
for _, aclPeer := range aclPeers {
|
for _, aclPeer := range aclPeers {
|
||||||
if aclPeer.ID == peer.ID {
|
if aclPeer.ID == peer.ID {
|
||||||
return peer, nil
|
return peer, nil
|
||||||
|
@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
ID: "peerB",
|
ID: "peerB",
|
||||||
IP: net.ParseIP("100.65.80.39"),
|
IP: net.ParseIP("100.65.80.39"),
|
||||||
Status: &nbpeer.PeerStatus{},
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
|
||||||
},
|
},
|
||||||
"peerC": {
|
"peerC": {
|
||||||
ID: "peerC",
|
ID: "peerC",
|
||||||
@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
IP: net.ParseIP("100.65.31.2"),
|
IP: net.ParseIP("100.65.31.2"),
|
||||||
Status: &nbpeer.PeerStatus{},
|
Status: &nbpeer.PeerStatus{},
|
||||||
},
|
},
|
||||||
|
"peerK": {
|
||||||
|
ID: "peerK",
|
||||||
|
IP: net.ParseIP("100.32.80.1"),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Groups: map[string]*types.Group{
|
Groups: map[string]*types.Group{
|
||||||
"GroupAll": {
|
"GroupAll": {
|
||||||
@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
"peerI",
|
"peerI",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"GroupWorkflow": {
|
||||||
|
ID: "GroupWorkflow",
|
||||||
|
Name: "workflow",
|
||||||
|
Peers: []string{
|
||||||
|
"peerK",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Policies: []*types.Policy{
|
Policies: []*types.Policy{
|
||||||
{
|
{
|
||||||
@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "RuleWorkflow",
|
||||||
|
Name: "Workflow",
|
||||||
|
Description: "No description",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "RuleWorkflow",
|
||||||
|
Name: "Workflow",
|
||||||
|
Description: "No description",
|
||||||
|
Bidirectional: true,
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
Action: types.PolicyTrafficActionAccept,
|
||||||
|
PortRanges: []types.RulePortRange{
|
||||||
|
{
|
||||||
|
Start: 8088,
|
||||||
|
End: 8088,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Start: 9090,
|
||||||
|
End: 9095,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Sources: []string{
|
||||||
|
"GroupWorkflow",
|
||||||
|
},
|
||||||
|
Destinations: []string{
|
||||||
|
"GroupDMZ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("check that all peers get map", func(t *testing.T) {
|
t.Run("check that all peers get map", func(t *testing.T) {
|
||||||
for _, p := range account.Peers {
|
for _, p := range account.Peers {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
||||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check first peer map details", func(t *testing.T) {
|
t.Run("check first peer map details", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
||||||
assert.Len(t, peers, 8)
|
assert.Len(t, peers, 8)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||||
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
||||||
|
assert.Len(t, peers, 1)
|
||||||
|
assert.Contains(t, peers, account.Peers["peerI"])
|
||||||
|
|
||||||
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "100.65.31.2",
|
||||||
|
Direction: types.FirewallRuleDirectionIN,
|
||||||
|
Action: "accept",
|
||||||
|
Protocol: "tcp",
|
||||||
|
Port: "8088",
|
||||||
|
PolicyID: "RuleWorkflow",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "100.65.31.2",
|
||||||
|
Direction: types.FirewallRuleDirectionOUT,
|
||||||
|
Action: "accept",
|
||||||
|
Protocol: "tcp",
|
||||||
|
Port: "8088",
|
||||||
|
PolicyID: "RuleWorkflow",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||||
@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("check first peer map", func(t *testing.T) {
|
t.Run("check first peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
epectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.254.139",
|
PeerIP: "100.65.254.139",
|
||||||
Direction: types.FirewallRuleDirectionIN,
|
Direction: types.FirewallRuleDirectionIN,
|
||||||
@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
PolicyID: "RuleSwarm",
|
PolicyID: "RuleSwarm",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||||
slices.SortFunc(firewallRules, sortFunc())
|
slices.SortFunc(firewallRules, sortFunc())
|
||||||
for i := range firewallRules {
|
for i := range firewallRules {
|
||||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map", func(t *testing.T) {
|
t.Run("check second peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
epectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.80.39",
|
PeerIP: "100.65.80.39",
|
||||||
Direction: types.FirewallRuleDirectionIN,
|
Direction: types.FirewallRuleDirectionIN,
|
||||||
@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
PolicyID: "RuleSwarm",
|
PolicyID: "RuleSwarm",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||||
slices.SortFunc(firewallRules, sortFunc())
|
slices.SortFunc(firewallRules, sortFunc())
|
||||||
for i := range firewallRules {
|
for i := range firewallRules {
|
||||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
account.Policies[1].Rules[0].Bidirectional = false
|
account.Policies[1].Rules[0].Bidirectional = false
|
||||||
|
|
||||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
epectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.254.139",
|
PeerIP: "100.65.254.139",
|
||||||
Direction: types.FirewallRuleDirectionOUT,
|
Direction: types.FirewallRuleDirectionOUT,
|
||||||
@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
PolicyID: "RuleSwarm",
|
PolicyID: "RuleSwarm",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||||
slices.SortFunc(firewallRules, sortFunc())
|
slices.SortFunc(firewallRules, sortFunc())
|
||||||
for i := range firewallRules {
|
for i := range firewallRules {
|
||||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
epectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.80.39",
|
PeerIP: "100.65.80.39",
|
||||||
Direction: types.FirewallRuleDirectionIN,
|
Direction: types.FirewallRuleDirectionIN,
|
||||||
@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
PolicyID: "RuleSwarm",
|
PolicyID: "RuleSwarm",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||||
slices.SortFunc(firewallRules, sortFunc())
|
slices.SortFunc(firewallRules, sortFunc())
|
||||||
for i := range firewallRules {
|
for i := range firewallRules {
|
||||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||||
// will establish a connection with all source peers satisfying the NB posture check.
|
// will establish a connection with all source peers satisfying the NB posture check.
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, 1)
|
assert.Len(t, firewallRules, 1)
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||||
|
|
||||||
@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||||
assert.Len(t, peers, 3)
|
assert.Len(t, peers, 3)
|
||||||
assert.Len(t, firewallRules, 3)
|
assert.Len(t, firewallRules, 3)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
||||||
assert.Len(t, peers, 5)
|
assert.Len(t, peers, 5)
|
||||||
// assert peers from Group Swarm
|
// assert peers from Group Swarm
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||||
peerVersion := sanitizeVersion(peer.Meta.WtVersion)
|
meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
|
||||||
minVersion := sanitizeVersion(n.MinVersion)
|
|
||||||
|
|
||||||
peerNBVersion, err := version.NewVersion(peerVersion)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
constraints, err := version.NewConstraint(">= " + minVersion)
|
if meetsMin {
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if constraints.Check(peerNBVersion) {
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
|
||||||
|
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
|
||||||
|
peerVer = sanitizeVersion(peerVer)
|
||||||
|
minVer = sanitizeVersion(minVer)
|
||||||
|
|
||||||
|
peerNBVer, err := version.NewVersion(peerVer)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
constraints, err := version.NewConstraint(">= " + minVer)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return constraints.Check(peerNBVer), nil
|
||||||
|
}
|
||||||
|
@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMeetsMinVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
minVer string
|
||||||
|
peerVer string
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Peer version greater than min version",
|
||||||
|
minVer: "0.26.0",
|
||||||
|
peerVer: "0.60.1",
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer version equals min version",
|
||||||
|
minVer: "1.0.0",
|
||||||
|
peerVer: "1.0.0",
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer version less than min version",
|
||||||
|
minVer: "1.0.0",
|
||||||
|
peerVer: "0.9.9",
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer version with pre-release tag greater than min version",
|
||||||
|
minVer: "1.0.0",
|
||||||
|
peerVer: "1.0.1-alpha",
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid peer version format",
|
||||||
|
minVer: "1.0.0",
|
||||||
|
peerVer: "dev",
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid min version format",
|
||||||
|
minVer: "invalid.version",
|
||||||
|
peerVer: "1.0.0",
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,19 +4,19 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
|
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
|
||||||
// routes can have both peer and peer_groups
|
// routes can have both peer and peer_groups
|
||||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
prefix := checkRoute.Network
|
||||||
|
domains := checkRoute.Domains
|
||||||
|
|
||||||
|
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||||
seenPeers := make(map[string]bool)
|
seenPeers := make(map[string]bool)
|
||||||
@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
for _, prefixRoute := range routesWithPrefix {
|
for _, prefixRoute := range routesWithPrefix {
|
||||||
// we skip route(s) with the same network ID as we want to allow updating of the existing route
|
// we skip route(s) with the same network ID as we want to allow updating of the existing route
|
||||||
// when creating a new route routeID is newly generated so nothing will be skipped
|
// when creating a new route routeID is newly generated so nothing will be skipped
|
||||||
if routeID == prefixRoute.ID {
|
if checkRoute.ID == prefixRoute.ID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefixRoute.Peer != "" {
|
if prefixRoute.Peer != "" {
|
||||||
seenPeers[string(prefixRoute.ID)] = true
|
seenPeers[string(prefixRoute.ID)] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, groupID := range prefixRoute.PeerGroups {
|
for _, groupID := range prefixRoute.PeerGroups {
|
||||||
seenPeerGroups[groupID] = true
|
seenPeerGroups[groupID] = true
|
||||||
|
|
||||||
group := account.GetGroup(groupID)
|
group, ok := peerGroupsMap[groupID]
|
||||||
if group == nil {
|
if !ok || group == nil {
|
||||||
return status.Errorf(
|
return status.Errorf(
|
||||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||||
getRouteDescriptor(prefix, domains), groupID,
|
getRouteDescriptor(prefix, domains), groupID,
|
||||||
@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if peerID != "" {
|
if peerID := checkRoute.Peer; peerID != "" {
|
||||||
// check that peerID exists and is not in any route as single peer or part of the group
|
// check that peerID exists and is not in any route as single peer or part of the group
|
||||||
peer := account.GetPeer(peerID)
|
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
|
||||||
if peer == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := seenPeers[peerID]; ok {
|
if _, ok := seenPeers[peerID]; ok {
|
||||||
return status.Errorf(status.AlreadyExists,
|
return status.Errorf(status.AlreadyExists,
|
||||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||||
@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that peerGroupIDs are not in any route peerGroups list
|
// check that peerGroupIDs are not in any route peerGroups list
|
||||||
for _, groupID := range peerGroupIDs {
|
for _, groupID := range checkRoute.PeerGroups {
|
||||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
|
||||||
|
|
||||||
if _, ok := seenPeerGroups[groupID]; ok {
|
if _, ok := seenPeerGroups[groupID]; ok {
|
||||||
return status.Errorf(
|
return status.Errorf(
|
||||||
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
|
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
|
||||||
@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||||
|
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range group.Peers {
|
for _, id := range group.Peers {
|
||||||
if _, ok := seenPeers[id]; ok {
|
if _, ok := seenPeers[id]; ok {
|
||||||
peer := account.GetPeer(id)
|
peer, ok := peersMap[id]
|
||||||
if peer == nil {
|
if !ok || peer == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.Errorf(status.AlreadyExists,
|
return status.Errorf(status.AlreadyExists,
|
||||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||||
@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(domains) > 0 && prefix.IsValid() {
|
if len(domains) > 0 && prefix.IsValid() {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(domains) == 0 && !prefix.IsValid() {
|
var newRoute *route.Route
|
||||||
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
newRoute = &route.Route{
|
||||||
|
ID: route.ID(xid.New().String()),
|
||||||
|
AccountID: accountID,
|
||||||
|
Network: prefix,
|
||||||
|
Domains: domains,
|
||||||
|
KeepRoute: keepRoute,
|
||||||
|
NetID: netID,
|
||||||
|
Description: description,
|
||||||
|
Peer: peerID,
|
||||||
|
PeerGroups: peerGroupIDs,
|
||||||
|
NetworkType: networkType,
|
||||||
|
Masquerade: masquerade,
|
||||||
|
Metric: metric,
|
||||||
|
Enabled: enabled,
|
||||||
|
Groups: groups,
|
||||||
|
AccessControlGroups: accessControlGroupIDs,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(domains) > 0 {
|
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
|
||||||
prefix = getPlaceholderIP()
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peerID != "" && len(peerGroupIDs) != 0 {
|
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||||
return nil, status.Errorf(
|
if err != nil {
|
||||||
status.InvalidArgument,
|
return err
|
||||||
"peer with ID %s and peers group %s should not be provided at the same time",
|
|
||||||
peerID, peerGroupIDs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newRoute route.Route
|
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||||
newRoute.ID = route.ID(xid.New().String())
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if len(peerGroupIDs) > 0 {
|
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
|
||||||
err = validateGroups(peerGroupIDs, account.Groups)
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(accessControlGroupIDs) > 0 {
|
|
||||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
|
||||||
}
|
|
||||||
|
|
||||||
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
|
||||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validateGroups(groups, account.Groups)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute.Peer = peerID
|
|
||||||
newRoute.PeerGroups = peerGroupIDs
|
|
||||||
newRoute.Network = prefix
|
|
||||||
newRoute.Domains = domains
|
|
||||||
newRoute.NetworkType = networkType
|
|
||||||
newRoute.Description = description
|
|
||||||
newRoute.NetID = netID
|
|
||||||
newRoute.Masquerade = masquerade
|
|
||||||
newRoute.Metric = metric
|
|
||||||
newRoute.Enabled = enabled
|
|
||||||
newRoute.Groups = groups
|
|
||||||
newRoute.KeepRoute = keepRoute
|
|
||||||
newRoute.AccessControlGroups = accessControlGroupIDs
|
|
||||||
|
|
||||||
if account.Routes == nil {
|
|
||||||
account.Routes = make(map[route.ID]*route.Route)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Routes[newRoute.ID] = &newRoute
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
|
||||||
am.UpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
return &newRoute, nil
|
if updateAccountPeers {
|
||||||
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newRoute, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveRoute saves route
|
// SaveRoute saves route
|
||||||
@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldRoute *route.Route
|
||||||
|
var oldRouteAffectsPeers bool
|
||||||
|
var newRouteAffectsPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
routeToSave.AccountID = accountID
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
|
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||||
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute deletes route with routeID
|
||||||
|
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var route *route.Route
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||||
|
})
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRoutes returns a list of routes from account
|
||||||
|
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
||||||
if routeToSave == nil {
|
if routeToSave == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||||
}
|
}
|
||||||
@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(routeToSave.PeerGroups) > 0 {
|
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
|
||||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||||
|
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
||||||
|
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||||
|
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routeToSave.PeerGroups) > 0 {
|
||||||
|
if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(routeToSave.AccessControlGroups) > 0 {
|
if len(routeToSave.AccessControlGroups) > 0 {
|
||||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
return groupsMap, nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldRoute := account.Routes[routeToSave.ID]
|
|
||||||
account.Routes[routeToSave.ID] = routeToSave
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
|
||||||
am.UpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRoute deletes route with routeID
|
|
||||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
|
||||||
if err != nil {
|
|
||||||
return status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
return status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
routy := account.Routes[routeID]
|
|
||||||
if routy == nil {
|
|
||||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
|
||||||
}
|
|
||||||
delete(account.Routes, routeID)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
|
||||||
|
|
||||||
if am.isRouteChangeAffectPeers(account, routy) {
|
|
||||||
am.UpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListRoutes returns a list of routes from account
|
|
||||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.NewPermissionValidationError(err)
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
return nil, status.NewPermissionDeniedError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
|||||||
return &portInfo
|
return &portInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||||
// if it has a routing peer, distribution, or peer groups that include peers
|
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
|
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||||
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
if route.Peer != "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||||
|
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||||
|
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make([]*route.Route, 0)
|
||||||
|
for _, r := range accountRoutes {
|
||||||
|
dynamic := r.IsDynamic()
|
||||||
|
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||||
|
!dynamic && r.Network.String() == prefix.String() {
|
||||||
|
routes = append(routes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
type Scheduler interface {
|
type Scheduler interface {
|
||||||
Cancel(ctx context.Context, IDs []string)
|
Cancel(ctx context.Context, IDs []string)
|
||||||
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||||
|
IsSchedulerRunning(ID string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockScheduler is a mock implementation of Scheduler
|
// MockScheduler is a mock implementation of Scheduler
|
||||||
@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
|
|||||||
mock.CancelFunc(ctx, IDs)
|
mock.CancelFunc(ctx, IDs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ")
|
log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Schedule mocks the Schedule function of the Scheduler interface
|
// Schedule mocks the Schedule function of the Scheduler interface
|
||||||
@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
|
|||||||
mock.ScheduleFunc(ctx, in, ID, job)
|
mock.ScheduleFunc(ctx, in, ID, job)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined")
|
log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
|
||||||
|
// MockScheduler does not implement IsSchedulerRunning, so we return false
|
||||||
|
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
|
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
|
||||||
@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s
|
|||||||
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSchedulerRunning checks if a job with the provided ID is scheduled to run
|
||||||
|
func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool {
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
_, ok := wm.jobs[ID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
@ -230,3 +230,7 @@ func NewUserRoleNotFoundError(role string) error {
|
|||||||
func NewOperationNotFoundError(operation operations.Operation) error {
|
func NewOperationNotFoundError(operation operations.Operation) error {
|
||||||
return Errorf(NotFound, "operation: %s not found", operation)
|
return Errorf(NotFound, "operation: %s not found", operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewRouteNotFoundError(routeID string) error {
|
||||||
|
return Errorf(NotFound, "route: %s not found", routeID)
|
||||||
|
}
|
||||||
|
@ -23,8 +23,6 @@ import (
|
|||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
@ -34,6 +32,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||||
return getRecords[*route.Route](s.db, lockStrength, accountID)
|
var routes []*route.Route
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Find(&routes, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRouteByID retrieves a route by its ID and account ID.
|
// GetRouteByID retrieves a route by its ID and account ID.
|
||||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||||
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
|
var route *route.Route
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewRouteNotFoundError(routeID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return route, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveRoute saves a route to the database.
|
||||||
|
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save route to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute deletes a route from the database.
|
||||||
|
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete route from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewRouteNotFoundError(routeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRecords retrieves records from the database based on the account ID.
|
|
||||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
|
||||||
tx := db
|
|
||||||
if lockStrength != LockingStrengthNone {
|
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
||||||
}
|
|
||||||
|
|
||||||
var record []T
|
|
||||||
|
|
||||||
result := tx.Find(&record, accountIDCondition, accountID)
|
|
||||||
if err := result.Error; err != nil {
|
|
||||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
|
||||||
recordType := parts[len(parts)-1]
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
|
||||||
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
|
||||||
tx := db
|
|
||||||
if lockStrength != LockingStrengthNone {
|
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
||||||
}
|
|
||||||
|
|
||||||
var record T
|
|
||||||
|
|
||||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
|
||||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
|
||||||
if err := result.Error; err != nil {
|
|
||||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
|
||||||
recordType := parts[len(parts)-1]
|
|
||||||
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
|
||||||
}
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveDNSSettings saves the DNS settings to the store.
|
// SaveDNSSettings saves the DNS settings to the store.
|
||||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
|
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||||
|
@ -19,21 +19,17 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
|
|
||||||
route2 "github.com/netbirdio/netbird/route"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
route2 "github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
||||||
@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 8003, len(accountGroups))
|
require.Equal(t, 8003, len(accountGroups))
|
||||||
}
|
}
|
||||||
|
func TestSqlStore_GetAccountRoutes(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve routes by existing account ID",
|
||||||
|
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing account ID",
|
||||||
|
accountID: "nonexistent",
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account ID",
|
||||||
|
accountID: "",
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, routes, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetRouteByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routeID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing route",
|
||||||
|
routeID: "ct03t427qv97vmtmglog",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing route",
|
||||||
|
routeID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty route ID",
|
||||||
|
routeID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, route)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, route)
|
||||||
|
require.Equal(t, tt.routeID, string(route.ID))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveRoute(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
route := &route2.Route{
|
||||||
|
ID: "route-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||||
|
NetID: "netID",
|
||||||
|
PeerGroups: []string{"routeA"},
|
||||||
|
NetworkType: route2.IPv4Network,
|
||||||
|
Masquerade: true,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"groupA"},
|
||||||
|
AccessControlGroups: []string{},
|
||||||
|
}
|
||||||
|
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, route, saveRoute)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteRoute(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
routeID := "ct03t427qv97vmtmglog"
|
||||||
|
|
||||||
|
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, route)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlStore_GetAccountMeta(t *testing.T) {
|
func TestSqlStore_GetAccountMeta(t *testing.T) {
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
@ -145,7 +145,9 @@ type Store interface {
|
|||||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||||
|
|
||||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||||
|
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||||
|
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||||
|
|
||||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
|
|||||||
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||||
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||||
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||||
|
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
@ -36,6 +36,9 @@ const (
|
|||||||
PublicCategory = "public"
|
PublicCategory = "public"
|
||||||
PrivateCategory = "private"
|
PrivateCategory = "private"
|
||||||
UnknownCategory = "unknown"
|
UnknownCategory = "unknown"
|
||||||
|
|
||||||
|
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
|
||||||
|
firewallRuleMinPortRangesVer = "0.48.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LookupMap map[string]struct{}
|
type LookupMap map[string]struct{}
|
||||||
@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap)
|
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||||
// exclude expired peers
|
// exclude expired peers
|
||||||
var peersToConnect []*nbpeer.Peer
|
var peersToConnect []*nbpeer.Peer
|
||||||
var expiredPeers []*nbpeer.Peer
|
var expiredPeers []*nbpeer.Peer
|
||||||
@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
|||||||
// GetPeerConnectionResources for a given peer
|
// GetPeerConnectionResources for a given peer
|
||||||
//
|
//
|
||||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||||
|
|
||||||
for _, policy := range a.Policies {
|
for _, policy := range a.Policies {
|
||||||
if !policy.Enabled {
|
if !policy.Enabled {
|
||||||
continue
|
continue
|
||||||
@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||||
|
|
||||||
if rule.Bidirectional {
|
if rule.Bidirectional {
|
||||||
if peerInSources {
|
if peerInSources {
|
||||||
@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
|
|||||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||||
// generated. The accumulator function returns the result of all the generator calls.
|
// generated. The accumulator function returns the result of all the generator calls.
|
||||||
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||||
rulesExists := make(map[string]struct{})
|
rulesExists := make(map[string]struct{})
|
||||||
peersExists := make(map[string]struct{})
|
peersExists := make(map[string]struct{})
|
||||||
rules := make([]*FirewallRule, 0)
|
rules := make([]*FirewallRule, 0)
|
||||||
@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, port := range rule.Ports {
|
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||||
pr := fr // clone rule and add set new port
|
|
||||||
pr.Port = port
|
|
||||||
rules = append(rules, &pr)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, portRange := range rule.PortRanges {
|
|
||||||
pr := fr
|
|
||||||
pr.PortRange = portRange
|
|
||||||
rules = append(rules, &pr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||||
return peers, rules
|
return peers, rules
|
||||||
@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||||
|
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||||
|
var expanded []*FirewallRule
|
||||||
|
|
||||||
|
if len(rule.Ports) > 0 {
|
||||||
|
for _, port := range rule.Ports {
|
||||||
|
fr := base
|
||||||
|
fr.Port = port
|
||||||
|
expanded = append(expanded, &fr)
|
||||||
|
}
|
||||||
|
return expanded
|
||||||
|
}
|
||||||
|
|
||||||
|
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||||
|
for _, portRange := range rule.PortRanges {
|
||||||
|
fr := base
|
||||||
|
|
||||||
|
if supportPortRanges {
|
||||||
|
fr.PortRange = portRange
|
||||||
|
} else {
|
||||||
|
// Peer doesn't support port ranges, only allow single-port ranges
|
||||||
|
if portRange.Start != portRange.End {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||||
|
}
|
||||||
|
expanded = append(expanded, &fr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return expanded
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||||
|
func peerSupportsPortRanges(peerVer string) bool {
|
||||||
|
if strings.Contains(peerVer, "dev") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||||
|
return err == nil && meetMinVer
|
||||||
|
}
|
||||||
|
@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
|
|||||||
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
||||||
} else {
|
} else {
|
||||||
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: generate IPv6 rules for dynamic routes
|
// TODO: generate IPv6 rules for dynamic routes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user