diff --git a/client/android/preferences.go b/client/android/preferences.go index 2a8b197e7..2d5668d1c 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -4,12 +4,12 @@ import ( "github.com/netbirdio/netbird/client/internal" ) -// Preferences export a subset of the internal config for gomobile +// Preferences exports a subset of the internal config for gomobile type Preferences struct { configInput internal.ConfigInput } -// NewPreferences create new Preferences instance +// NewPreferences creates a new Preferences instance func NewPreferences(configPath string) *Preferences { ci := internal.ConfigInput{ ConfigPath: configPath, @@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences { return &Preferences{ci} } -// GetManagementURL read url from config file +// GetManagementURL reads URL from config file func (p *Preferences) GetManagementURL() (string, error) { if p.configInput.ManagementURL != "" { return p.configInput.ManagementURL, nil @@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) { return cfg.ManagementURL.String(), err } -// SetManagementURL store the given url and wait for commit +// SetManagementURL stores the given URL and waits for commit func (p *Preferences) SetManagementURL(url string) { p.configInput.ManagementURL = url } -// GetAdminURL read url from config file +// GetAdminURL reads URL from config file func (p *Preferences) GetAdminURL() (string, error) { if p.configInput.AdminURL != "" { return p.configInput.AdminURL, nil @@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) { return cfg.AdminURL.String(), err } -// SetAdminURL store the given url and wait for commit +// SetAdminURL stores the given URL and waits for commit func (p *Preferences) SetAdminURL(url string) { p.configInput.AdminURL = url } -// GetPreSharedKey read preshared key from config file +// GetPreSharedKey reads pre-shared key from config file func (p *Preferences) GetPreSharedKey() (string, error) { if p.configInput.PreSharedKey != nil { return *p.configInput.PreSharedKey, nil @@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) { return cfg.PreSharedKey, err } -// SetPreSharedKey store the given key and wait for commit +// SetPreSharedKey stores the given key and waits for commit func (p *Preferences) SetPreSharedKey(key string) { p.configInput.PreSharedKey = &key } -// SetRosenpassEnabled store if rosenpass is enabled +// SetRosenpassEnabled stores whether Rosenpass is enabled func (p *Preferences) SetRosenpassEnabled(enabled bool) { p.configInput.RosenpassEnabled = &enabled } -// GetRosenpassEnabled read rosenpass enabled from config file +// GetRosenpassEnabled reads Rosenpass enabled status from config file func (p *Preferences) GetRosenpassEnabled() (bool, error) { if p.configInput.RosenpassEnabled != nil { return *p.configInput.RosenpassEnabled, nil @@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) { return cfg.RosenpassEnabled, err } -// SetRosenpassPermissive store the given permissive and wait for commit +// SetRosenpassPermissive stores the given permissive setting and waits for commit func (p *Preferences) SetRosenpassPermissive(permissive bool) { p.configInput.RosenpassPermissive = &permissive } -// GetRosenpassPermissive read rosenpass permissive from config file +// GetRosenpassPermissive reads Rosenpass permissive setting from config file func (p *Preferences) GetRosenpassPermissive() (bool, error) { if p.configInput.RosenpassPermissive != nil { return *p.configInput.RosenpassPermissive, nil @@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return cfg.RosenpassPermissive, err } -// Commit write out the changes into config file +// GetDisableClientRoutes reads disable client routes setting from config file +func (p *Preferences) GetDisableClientRoutes() (bool, error) { + if p.configInput.DisableClientRoutes != nil { + return *p.configInput.DisableClientRoutes, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableClientRoutes, err +} + +// SetDisableClientRoutes stores the given value and waits for commit +func (p *Preferences) SetDisableClientRoutes(disable bool) { + p.configInput.DisableClientRoutes = &disable +} + +// GetDisableServerRoutes reads disable server routes setting from config file +func (p *Preferences) GetDisableServerRoutes() (bool, error) { + if p.configInput.DisableServerRoutes != nil { + return *p.configInput.DisableServerRoutes, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableServerRoutes, err +} + +// SetDisableServerRoutes stores the given value and waits for commit +func (p *Preferences) SetDisableServerRoutes(disable bool) { + p.configInput.DisableServerRoutes = &disable +} + +// GetDisableDNS reads disable DNS setting from config file +func (p *Preferences) GetDisableDNS() (bool, error) { + if p.configInput.DisableDNS != nil { + return *p.configInput.DisableDNS, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableDNS, err +} + +// SetDisableDNS stores the given value and waits for commit +func (p *Preferences) SetDisableDNS(disable bool) { + p.configInput.DisableDNS = &disable +} + +// GetDisableFirewall reads disable firewall setting from config file +func (p *Preferences) GetDisableFirewall() (bool, error) { + if p.configInput.DisableFirewall != nil { + return *p.configInput.DisableFirewall, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableFirewall, err +} + +// SetDisableFirewall stores the given value and waits for commit +func (p *Preferences) SetDisableFirewall(disable bool) { + p.configInput.DisableFirewall = &disable +} + +// GetServerSSHAllowed reads server SSH allowed setting from config file +func (p *Preferences) GetServerSSHAllowed() (bool, error) { + if p.configInput.ServerSSHAllowed != nil { + return *p.configInput.ServerSSHAllowed, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.ServerSSHAllowed == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.ServerSSHAllowed, err +} + +// SetServerSSHAllowed stores the given value and waits for commit +func (p *Preferences) SetServerSSHAllowed(allowed bool) { + p.configInput.ServerSSHAllowed = &allowed +} + +// GetBlockInbound reads block inbound setting from config file +func (p *Preferences) GetBlockInbound() (bool, error) { + if p.configInput.BlockInbound != nil { + return *p.configInput.BlockInbound, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.BlockInbound, err +} + +// SetBlockInbound stores the given value and waits for commit +func (p *Preferences) SetBlockInbound(block bool) { + p.configInput.BlockInbound = &block +} + +// Commit writes out the changes to the config file func (p *Preferences) Commit() error { _, err := internal.UpdateOrCreateConfig(p.configInput) return err diff --git a/client/cmd/system.go b/client/cmd/system.go index 83ce8d215..f63432401 100644 --- a/client/cmd/system.go +++ b/client/cmd/system.go @@ -38,5 +38,5 @@ func init() { upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ - "This overrides any policies received from the management service.") + "This overrides any policies received from the management service.") } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index ab3e611e1..ae9e29bd1 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -24,6 +24,7 @@ type WGTunDevice struct { mtu int iceBind *bind.ICEBind tunAdapter TunAdapter + disableDNS bool name string device *device.Device @@ -32,7 +33,7 @@ type WGTunDevice struct { configurer WGConfigurer } -func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { +func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ address: address, port: port, @@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind mtu: mtu, iceBind: iceBind, tunAdapter: tunAdapter, + disableDNS: disableDNS, } } @@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string routesString := routesToString(routes) searchDomainsToString := searchDomainsToString(searchDomains) + // Skip DNS configuration when DisableDNS is enabled + if t.disableDNS { + log.Info("DNS is disabled, skipping DNS and search domain configuration") + dns = "" + searchDomainsToString = "" + } + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) diff --git a/client/iface/iface.go b/client/iface/iface.go index 7d609f4cd..006dfe4e7 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -43,6 +43,7 @@ type WGIFaceOpts struct { MobileArgs *device.MobileIFaceArguments TransportNet transport.Net FilterFn bind.FilterFn + DisableDNS bool } // WGIface represents an interface instance diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 35046b887..c8babea32 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), + tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), wgProxyFactory: wgproxy.NewUSPFactory(iceBind), } return wgIFace, nil diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index c8bc9123b..32dc7fbb8 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules( // // We zeroed this to notify squash function that this protocol can't be squashed. addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" - if drop { + hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || + r.Port != "" || !portInfoEmpty(r.PortInfo) + + if hasPortRestrictions { + // Don't squash rules with port restrictions protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} return } + if _, ok := protocols[r.Protocol]; !ok { protocols[r.Protocol] = &protoMatch{ ips: map[string]int{}, diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 16620033e..b378de8c8 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { assert.Equal(t, len(networkMap.FirewallRules), len(rules)) } +func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { + tests := []struct { + name string + rules []*mgmProto.FirewallRule + expectedCount int + description string + }{ + { + name: "should not squash rules with port ranges", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 8080, + End: 8090, + }, + }, + }, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 8080, + End: 8090, + }, + }, + }, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 8080, + End: 8090, + }, + }, + }, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 8080, + End: 8090, + }, + }, + }, + }, + }, + expectedCount: 4, + description: "Rules with port ranges should not be squashed even if they cover all peers", + }, + { + name: "should not squash rules with specific ports", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + }, + }, + expectedCount: 4, + description: "Rules with specific ports should not be squashed even if they cover all peers", + }, + { + name: "should not squash rules with legacy port field", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + }, + expectedCount: 4, + description: "Rules with legacy port field should not be squashed", + }, + { + name: "should not squash rules with DROP action", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + }, + }, + expectedCount: 4, + description: "Rules with DROP action should not be squashed", + }, + { + name: "should squash rules without port restrictions", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + }, + expectedCount: 1, + description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", + }, + { + name: "mixed rules should not squash protocol with port restrictions", + rules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + PortInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + }, + }, + expectedCount: 4, + description: "TCP should not be squashed because one rule has port restrictions", + }, + { + name: "should squash UDP but not TCP when TCP has port restrictions", + rules: []*mgmProto.FirewallRule{ + // TCP rules with port restrictions - should NOT be squashed + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + // UDP rules without port restrictions - SHOULD be squashed + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, + }, + { + PeerIP: "10.93.0.4", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, + }, + }, + expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) + description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networkMap := &mgmProto.NetworkMap{ + RemotePeers: []*mgmProto.RemotePeerConfig{ + {AllowedIps: []string{"10.93.0.1"}}, + {AllowedIps: []string{"10.93.0.2"}}, + {AllowedIps: []string{"10.93.0.3"}}, + {AllowedIps: []string{"10.93.0.4"}}, + }, + FirewallRules: tt.rules, + } + + manager := &DefaultManager{} + rules, _ := manager.squashAcceptRules(networkMap) + + assert.Equal(t, tt.expectedCount, len(rules), tt.description) + + // For squashed rules, verify we get the expected 0.0.0.0 rule + if tt.expectedCount == 1 { + assert.Equal(t, "0.0.0.0", rules[0].PeerIP) + assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) + assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) + } + }) + } +} + +func TestPortInfoEmpty(t *testing.T) { + tests := []struct { + name string + portInfo *mgmProto.PortInfo + expected bool + }{ + { + name: "nil PortInfo should be empty", + portInfo: nil, + expected: true, + }, + { + name: "PortInfo with zero port should be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 0, + }, + }, + expected: true, + }, + { + name: "PortInfo with valid port should not be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Port{ + Port: 80, + }, + }, + expected: false, + }, + { + name: "PortInfo with nil range should be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: nil, + }, + }, + expected: true, + }, + { + name: "PortInfo with zero start range should be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 0, + End: 100, + }, + }, + }, + expected: true, + }, + { + name: "PortInfo with zero end range should be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 80, + End: 0, + }, + }, + }, + expected: true, + }, + { + name: "PortInfo with valid range should not be empty", + portInfo: &mgmProto.PortInfo{ + PortSelection: &mgmProto.PortInfo_Range_{ + Range: &mgmProto.PortInfo_Range{ + Start: 8080, + End: 8090, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := portInfoEmpty(tt.portInfo) + assert.Equal(t, tt.expected, result) + }) + } +} + func TestDefaultManagerEnableSSHRules(t *testing.T) { networkMap := &mgmProto.NetworkMap{ PeerConfig: &mgmProto.PeerConfig{ diff --git a/client/internal/config.go b/client/internal/config.go index 45a7620e1..37ee1e1bf 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + // default to disabling server routes on Android for security + DisableServerRoutes: runtime.GOOS == "android", } if _, err := config.apply(input); err != nil { @@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { config.ServerSSHAllowed = input.ServerSSHAllowed updated = true } else if config.ServerSSHAllowed == nil { - // enables SSH for configs from old versions to preserve backwards compatibility - log.Infof("falling back to enabled SSH server for pre-existing configuration") - config.ServerSSHAllowed = util.True() + if runtime.GOOS == "android" { + // default to disabled SSH on Android for security + log.Infof("setting SSH server to false by default on Android") + config.ServerSSHAllowed = util.False() + } else { + // enables SSH for configs from old versions to preserve backwards compatibility + log.Infof("falling back to enabled SSH server for pre-existing configuration") + config.ServerSSHAllowed = util.True() + } updated = true } diff --git a/client/internal/conn_mgr.go b/client/internal/conn_mgr.go index aac312dc3..c630d3052 100644 --- a/client/internal/conn_mgr.go +++ b/client/internal/conn_mgr.go @@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co PeerConnID: conn.ConnID(), Log: conn.Log, } - excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) + excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg) if err != nil { conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) if err := conn.Open(ctx); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 253ecb2a6..4ea6fbd94 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { MTU: iface.DefaultMTU, TransportNet: transportNet, FilterFn: e.addrViaRoutes, + DisableDNS: e.config.DisableDNS, } switch runtime.GOOS { diff --git a/client/internal/lazyconn/inactivity/inactivity.go b/client/internal/lazyconn/inactivity/inactivity.go index a30c1846d..9b7c8511b 100644 --- a/client/internal/lazyconn/inactivity/inactivity.go +++ b/client/internal/lazyconn/inactivity/inactivity.go @@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() { func (i *Monitor) ResetTimer() { i.timer.Reset(i.inactivityThreshold) } + +func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) { + i.Stop() + go i.Start(ctx, timeoutChan) +} diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index 718bdbddf..74ede50a7 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -58,7 +58,7 @@ type Manager struct { // Route HA group management peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group - routesMu sync.RWMutex // protects route mappings + routesMu sync.RWMutex onInactive chan peerid.ConnID } @@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) { case peerConnID := <-m.activityManager.OnActivityChan: m.onPeerActivity(ctx, peerConnID) case peerConnID := <-m.onInactive: - m.onPeerInactivityTimedOut(peerConnID) + m.onPeerInactivityTimedOut(ctx, peerConnID) } } } @@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo return added } -func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { +func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { peerCfg: &peerCfg, expectedWatcher: watcherActivity, } + + // Check if this peer should be activated because its HA group peers are active + if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok { + peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group) + m.activateNewPeerInActiveGroup(ctx, peerCfg) + } + return false, nil } @@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { + var peersToActivate []string + m.routesMu.RLock() haGroups := m.peerToHAGroups[triggerPeerID] - m.routesMu.RUnlock() if len(haGroups) == 0 { + m.routesMu.RUnlock() log.Debugf("peer %s is not part of any HA groups", triggerPeerID) return } - activatedCount := 0 for _, haGroup := range haGroups { - m.routesMu.RLock() peers := m.haGroupToPeers[haGroup] - m.routesMu.RUnlock() - for _, peerID := range peers { - if peerID == triggerPeerID { - continue + if peerID != triggerPeerID { + peersToActivate = append(peersToActivate, peerID) } + } + } + m.routesMu.RUnlock() - cfg, mp := m.getPeerForActivation(peerID) - if cfg == nil { - continue - } + activatedCount := 0 + for _, peerID := range peersToActivate { + cfg, mp := m.getPeerForActivation(peerID) + if cfg == nil { + continue + } - if m.activateSinglePeer(ctx, cfg, mp) { - activatedCount++ - cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID) - m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) - } + if m.activateSinglePeer(ctx, cfg, mp) { + activatedCount++ + cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID) + m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) } } @@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string } } +// shouldActivateNewPeer checks if a newly added peer should be activated +// because other peers in its HA groups are already active +func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) { + m.routesMu.RLock() + defer m.routesMu.RUnlock() + + haGroups := m.peerToHAGroups[peerID] + if len(haGroups) == 0 { + return "", false + } + + for _, haGroup := range haGroups { + peers := m.haGroupToPeers[haGroup] + for _, groupPeerID := range peers { + if groupPeerID == peerID { + continue + } + + cfg, ok := m.managedPeers[groupPeerID] + if !ok { + continue + } + if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity { + return haGroup, true + } + } + } + return "", false +} + +// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group +func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) { + mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID] + if !ok { + return + } + + if !m.activateSinglePeer(ctx, &peerCfg, mp) { + return + } + + peerCfg.Log.Infof("activated newly added peer due to active HA group peers") + m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey) +} + func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { peerCfg.Log.Warnf("peer already managed") @@ -415,6 +469,48 @@ func (m *Manager) close() { log.Infof("lazy connection manager closed") } +// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements +func (m *Manager) shouldDeferIdleForHA(peerID string) bool { + m.routesMu.RLock() + defer m.routesMu.RUnlock() + + haGroups := m.peerToHAGroups[peerID] + if len(haGroups) == 0 { + return false + } + + for _, haGroup := range haGroups { + groupPeers := m.haGroupToPeers[haGroup] + + for _, groupPeerID := range groupPeers { + if groupPeerID == peerID { + continue + } + + cfg, ok := m.managedPeers[groupPeerID] + if !ok { + continue + } + + groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID] + if !ok { + continue + } + + if groupMp.expectedWatcher != watcherInactivity { + continue + } + + // Other member is still connected, defer idle + if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() { + return true + } + } + } + + return false +} + func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) } -func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { +func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { return } + if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) { + iw, ok := m.inactivityMonitors[peerConnID] + if ok { + mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements") + iw.ResetMonitor(ctx, m.onInactive) + } else { + mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset") + } + return + } + mp.peerCfg.Log.Infof("connection timed out") // this is blocking operation, potentially can be optimized @@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) { iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] if !ok { - mp.peerCfg.Log.Errorf("inactivity monitor not found for peer") + mp.peerCfg.Log.Warnf("inactivity monitor not found for peer") return } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index b33023873..c3f44cc7f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig { return conn.config.WgConfig } -// IsConnected unit tests only -// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn +// IsConnected returns true if the peer is connected func (conn *Conn) IsConnected() bool { conn.mu.Lock() defer conn.mu.Unlock() - return conn.currentConnPriority != conntype.None + + return conn.evalStatus() == StatusConnected } func (conn *Conn) GetKey() string { diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 63bad689e..742294cdf 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -15,7 +15,7 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) - UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error + UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap diff --git a/client/internal/routemanager/notifier/notifier.go b/client/internal/routemanager/notifier/notifier.go index 25a3a71e0..3cc7c3308 100644 --- a/client/internal/routemanager/notifier/notifier.go +++ b/client/internal/routemanager/notifier/notifier.go @@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { nets := make([]string, 0) for _, r := range clientRoutes { - // filter out domain routes if r.IsDynamic() { continue } @@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) { if runtime.GOOS != "android" { return } - newNets := make([]string, 0) + + var newNets []string for _, routes := range idMap { for _, r := range routes { + if r.IsDynamic() { + continue + } newNets = append(newNets, r.Network.String()) } } sort.Strings(newNets) - switch runtime.GOOS { - case "android": - if !n.hasDiff(n.initialRouteRanges, newNets) { - return - } - default: - if !n.hasDiff(n.routeRanges, newNets) { - return - } + if !n.hasDiff(n.initialRouteRanges, newNets) { + return } n.routeRanges = newNets - n.notify() } +// OnNewPrefixes is called from iOS only func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { newNets := make([]string, 0) for _, prefix := range prefixes { @@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { } sort.Strings(newNets) - switch runtime.GOOS { - case "android": - if !n.hasDiff(n.initialRouteRanges, newNets) { - return - } - default: - if !n.hasDiff(n.routeRanges, newNets) { - return - } + if !n.hasDiff(n.routeRanges, newNets) { + return } n.routeRanges = newNets - n.notify() } diff --git a/client/system/info.go b/client/system/info.go index a0a5fe8b3..aff10ece3 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -59,16 +59,16 @@ type Info struct { Environment Environment Files []File // for posture checks - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed bool - DisableClientRoutes bool - DisableServerRoutes bool - DisableDNS bool - DisableFirewall bool - BlockLANAccess bool - BlockInbound bool + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + BlockLANAccess bool + BlockInbound bool LazyConnectionEnabled bool } diff --git a/management/server/account.go b/management/server/account.go index daeaf6e55..b376f6f5e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1853,40 +1853,49 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C } func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { - account, err := am.Store.GetAccount(ctx, accountId) + var account *types.Account + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + account, err = transaction.GetAccount(ctx, accountId) + if err != nil { + return err + } + + if account.IsDomainPrimaryAccount { + return nil + } + + existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) + + // error is not a not found error + if handleNotFound(err) != nil { + return err + } + + // a primary account already exists for this private domain + if err == nil { + log.WithContext(ctx).WithFields(log.Fields{ + "accountId": accountId, + "existingAccountId": existingPrimaryAccountID, + }).Errorf("cannot update account to primary, another account already exists as primary for the same domain") + return status.Errorf(status.Internal, "cannot update account to primary") + } + + account.IsDomainPrimaryAccount = true + + if err := transaction.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).WithFields(log.Fields{ + "accountId": accountId, + }).Errorf("failed to update account to primary: %v", err) + return status.Errorf(status.Internal, "failed to update account to primary") + } + + return nil + }) if err != nil { return nil, err } - if account.IsDomainPrimaryAccount { - return account, nil - } - - existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) - - // error is not a not found error - if handleNotFound(err) != nil { - return nil, err - } - - // a primary account already exists for this private domain - if err == nil { - log.WithContext(ctx).WithFields(log.Fields{ - "accountId": accountId, - "existingAccountId": existingPrimaryAccountID, - }).Errorf("cannot update account to primary, another account already exists as primary for the same domain") - return nil, status.Errorf(status.Internal, "cannot update account to primary") - } - - account.IsDomainPrimaryAccount = true - - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).WithFields(log.Fields{ - "accountId": accountId, - }).Errorf("failed to update account to primary: %v", err) - return nil, status.Errorf(status.Internal, "failed to update account to primary") - } - return account, nil } diff --git a/management/server/group.go b/management/server/group.go index c26a0cfc1..130a67145 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if group, exists := account.Groups[groupID]; exists && group.HasPeers() { - return true - } - } - return false -} - // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 58134d375..1c5ca9b04 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -426,6 +426,10 @@ components: items: type: string example: "stage-host-1" + ephemeral: + description: Indicates whether the peer is ephemeral or not + type: boolean + example: false required: - city_name - connected @@ -450,6 +454,7 @@ components: - approval_required - serial_number - extra_dns_labels + - ephemeral AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 0a09d7ca2..d27fd2a57 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1016,6 +1016,9 @@ type Peer struct { // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` + // Ephemeral Indicates whether the peer is ephemeral or not + Ephemeral bool `json:"ephemeral"` + // ExtraDnsLabels Extra DNS labels added to the peer ExtraDnsLabels []string `json:"extra_dns_labels"` @@ -1097,6 +1100,9 @@ type PeerBatch struct { // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` + // Ephemeral Indicates whether the peer is ephemeral or not + Ephemeral bool `json:"ephemeral"` + // ExtraDnsLabels Extra DNS labels added to the peer ExtraDnsLabels []string `json:"extra_dns_labels"` diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 58ea06ea3..8c20ed65f 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, + Ephemeral: peer.Ephemeral, } } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index ef77bf10c..edb89466c 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - a, err := am.Store.GetAccountByUser(ctx, userID) - if err != nil { - return err - } + return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + a, err := transaction.GetAccountByUser(ctx, userID) + if err != nil { + return err + } - var extra *types.ExtraSettings + var extra *types.ExtraSettings - if a.Settings.Extra != nil { - extra = a.Settings.Extra - } else { - extra = &types.ExtraSettings{} - a.Settings.Extra = extra - } - extra.IntegratedValidatorGroups = groups - return am.Store.SaveAccount(ctx, a) + if a.Settings.Extra != nil { + extra = a.Settings.Extra + } else { + extra = &types.ExtraSettings{} + a.Settings.Extra = extra + } + extra.IntegratedValidatorGroups = groups + return transaction.SaveAccount(ctx, a) + }) } func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { diff --git a/management/server/peer.go b/management/server/peer.go index f2469e09b..1a1289721 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -1149,7 +1149,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun } for _, p := range userPeers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peer.ID { return peer, nil diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 679ec3b86..4352f3cff 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { ID: "peerB", IP: net.ParseIP("100.65.80.39"), Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"}, }, "peerC": { ID: "peerC", @@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) { IP: net.ParseIP("100.65.31.2"), Status: &nbpeer.PeerStatus{}, }, + "peerK": { + ID: "peerK", + IP: net.ParseIP("100.32.80.1"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"}, + }, }, Groups: map[string]*types.Group{ "GroupAll": { @@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerI", }, }, + "GroupWorkflow": { + ID: "GroupWorkflow", + Name: "workflow", + Peers: []string{ + "peerK", + }, + }, }, Policies: []*types.Policy{ { @@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, + { + ID: "RuleWorkflow", + Name: "Workflow", + Description: "No description", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "RuleWorkflow", + Name: "Workflow", + Description: "No description", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ + { + Start: 8088, + End: 8088, + }, + { + Start: 9090, + End: 9095, + }, + }, + Sources: []string{ + "GroupWorkflow", + }, + Destinations: []string{ + "GroupDMZ", + }, + }, + }, + }, }, } @@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) - assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") - assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers) + assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present") + assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers) assert.Len(t, peers, 8) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.True(t, contains, "rule not found in expected rules %#v", rule) } }) + + t.Run("check port ranges support for older peers", func(t *testing.T) { + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers) + assert.Len(t, peers, 1) + assert.Contains(t, peers, account.Peers["peerI"]) + + expectedFirewallRules := []*types.FirewallRule{ + { + PeerIP: "100.65.31.2", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "tcp", + Port: "8088", + PolicyID: "RuleWorkflow", + }, + { + PeerIP: "100.65.31.2", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "8088", + PolicyID: "RuleWorkflow", + }, + } + assert.ElementsMatch(t, firewallRules, expectedFirewallRules) + }) } func TestAccount_getPeersByPolicyDirect(t *testing.T) { @@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*types.FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", Direction: types.FirewallRuleDirectionIN, @@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { PolicyID: "RuleSwarm", }, } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) + assert.Len(t, firewallRules, len(expectedFirewallRules)) + slices.SortFunc(expectedFirewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc()) for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + assert.Equal(t, expectedFirewallRules[i], firewallRules[i]) } }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*types.FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", Direction: types.FirewallRuleDirectionIN, @@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { PolicyID: "RuleSwarm", }, } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) + assert.Len(t, firewallRules, len(expectedFirewallRules)) + slices.SortFunc(expectedFirewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc()) for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + assert.Equal(t, expectedFirewallRules[i], firewallRules[i]) } }) account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*types.FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", Direction: types.FirewallRuleDirectionOUT, @@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { PolicyID: "RuleSwarm", }, } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) + assert.Len(t, firewallRules, len(expectedFirewallRules)) + slices.SortFunc(expectedFirewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc()) for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + assert.Equal(t, expectedFirewallRules[i], firewallRules[i]) } }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*types.FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", Direction: types.FirewallRuleDirectionIN, @@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { PolicyID: "RuleSwarm", }, } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) + assert.Len(t, firewallRules, len(expectedFirewallRules)) + slices.SortFunc(expectedFirewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc()) for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + assert.Equal(t, expectedFirewallRules[i], firewallRules[i]) } }) } @@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*types.FirewallRule{ @@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/posture/nb_version.go b/management/server/posture/nb_version.go index e98e8e795..33bf01ad1 100644 --- a/management/server/posture/nb_version.go +++ b/management/server/posture/nb_version.go @@ -24,20 +24,12 @@ func sanitizeVersion(version string) string { } func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { - peerVersion := sanitizeVersion(peer.Meta.WtVersion) - minVersion := sanitizeVersion(n.MinVersion) - - peerNBVersion, err := version.NewVersion(peerVersion) + meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion) if err != nil { return false, err } - constraints, err := version.NewConstraint(">= " + minVersion) - if err != nil { - return false, err - } - - if constraints.Check(peerNBVersion) { + if meetsMin { return true, nil } @@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error { } return nil } + +// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version +func MeetsMinVersion(minVer, peerVer string) (bool, error) { + peerVer = sanitizeVersion(peerVer) + minVer = sanitizeVersion(minVer) + + peerNBVer, err := version.NewVersion(peerVer) + if err != nil { + return false, err + } + + constraints, err := version.NewConstraint(">= " + minVer) + if err != nil { + return false, err + } + + return constraints.Check(peerNBVer), nil +} diff --git a/management/server/posture/nb_version_test.go b/management/server/posture/nb_version_test.go index 1bf485453..d3478afc2 100644 --- a/management/server/posture/nb_version_test.go +++ b/management/server/posture/nb_version_test.go @@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) { }) } } + +func TestMeetsMinVersion(t *testing.T) { + tests := []struct { + name string + minVer string + peerVer string + want bool + wantErr bool + }{ + { + name: "Peer version greater than min version", + minVer: "0.26.0", + peerVer: "0.60.1", + want: true, + wantErr: false, + }, + { + name: "Peer version equals min version", + minVer: "1.0.0", + peerVer: "1.0.0", + want: true, + wantErr: false, + }, + { + name: "Peer version less than min version", + minVer: "1.0.0", + peerVer: "0.9.9", + want: false, + wantErr: false, + }, + { + name: "Peer version with pre-release tag greater than min version", + minVer: "1.0.0", + peerVer: "1.0.1-alpha", + want: true, + wantErr: false, + }, + { + name: "Invalid peer version format", + minVer: "1.0.0", + peerVer: "dev", + want: false, + wantErr: true, + }, + { + name: "Invalid min version format", + minVer: "invalid.version", + peerVer: "1.0.0", + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MeetsMinVersion(tt.minVer, tt.peerVer) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/management/server/route.go b/management/server/route.go index 02755a708..32ff39977 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,19 +4,19 @@ import ( "context" "fmt" "net/netip" + "slices" "unicode/utf8" "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.NewPermissionDeniedError() } - return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + prefix := checkRoute.Network + domains := checkRoute.Domains + + routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains) + if err != nil { + return err + } // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, prefixRoute := range routesWithPrefix { // we skip route(s) with the same network ID as we want to allow updating of the existing route // when creating a new route routeID is newly generated so nothing will be skipped - if routeID == prefixRoute.ID { + if checkRoute.ID == prefixRoute.ID { continue } if prefixRoute.Peer != "" { seenPeers[string(prefixRoute.ID)] = true } + + peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) + if err != nil { + return err + } + for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group := account.GetGroup(groupID) - if group == nil { + group, ok := peerGroupsMap[groupID] + if !ok || group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } } - if peerID != "" { + if peerID := checkRoute.Peer; peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer := account.GetPeer(peerID) - if peer == nil { + _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) + if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that peerGroupIDs are not in any route peerGroups list - for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. - + for _, groupID := range checkRoute.PeerGroups { + group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", @@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix + peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) + if err != nil { + return err + } + for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(id) - if peer == nil { - return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + peer, ok := peersMap[id] + if !ok || peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id) } + return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, status.NewPermissionDeniedError() } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } - if len(domains) == 0 && !prefix.IsValid() { - return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") - } + var newRoute *route.Route + var updateAccountPeers bool - if len(domains) > 0 { - prefix = getPlaceholderIP() - } - - if peerID != "" && len(peerGroupIDs) != 0 { - return nil, status.Errorf( - status.InvalidArgument, - "peer with ID %s and peers group %s should not be provided at the same time", - peerID, peerGroupIDs) - } - - var newRoute route.Route - newRoute.ID = route.ID(xid.New().String()) - - if len(peerGroupIDs) > 0 { - err = validateGroups(peerGroupIDs, account.Groups) - if err != nil { - return nil, err + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + newRoute = &route.Route{ + ID: route.ID(xid.New().String()), + AccountID: accountID, + Network: prefix, + Domains: domains, + KeepRoute: keepRoute, + NetID: netID, + Description: description, + Peer: peerID, + PeerGroups: peerGroupIDs, + NetworkType: networkType, + Masquerade: masquerade, + Metric: metric, + Enabled: enabled, + Groups: groups, + AccessControlGroups: accessControlGroupIDs, } - } - if len(accessControlGroupIDs) > 0 { - err = validateGroups(accessControlGroupIDs, account.Groups) - if err != nil { - return nil, err + if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { + return err } - } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) + }) if err != nil { return nil, err } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) - } - - if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - - err = validateGroups(groups, account.Groups) - if err != nil { - return nil, err - } - - newRoute.Peer = peerID - newRoute.PeerGroups = peerGroupIDs - newRoute.Network = prefix - newRoute.Domains = domains - newRoute.NetworkType = networkType - newRoute.Description = description - newRoute.NetID = netID - newRoute.Masquerade = masquerade - newRoute.Metric = metric - newRoute.Enabled = enabled - newRoute.Groups = groups - newRoute.KeepRoute = keepRoute - newRoute.AccessControlGroups = accessControlGroupIDs - - if account.Routes == nil { - account.Routes = make(map[route.ID]*route.Route) - } - - account.Routes[newRoute.ID] = &newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if am.isRouteChangeAffectPeers(account, &newRoute) { - am.UpdateAccountPeers(ctx, accountID) - } - am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - return &newRoute, nil + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return newRoute, nil } // SaveRoute saves route @@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var oldRoute *route.Route + var oldRouteAffectsPeers bool + var newRouteAffectsPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { + return err + } + + oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID)) + if err != nil { + return err + } + + oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute) + if err != nil { + return err + } + + newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave) + if err != nil { + return err + } + routeToSave.AccountID = accountID + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + + if oldRouteAffectsPeers || newRouteAffectsPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// DeleteRoute deletes route with routeID +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var route *route.Route + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + if err != nil { + return err + } + + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + }) + + am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// ListRoutes returns a list of routes from account +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) +} + +func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { if routeToSave == nil { return status.Errorf(status.InvalidArgument, "route provided is nil") } @@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } + groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave) + if err != nil { + return err + } + + return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap) +} + +// validateRouteGroups validates the route groups and returns the validated groups map. +func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { + groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) + groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) + if err != nil { + return nil, err + } + if len(routeToSave.PeerGroups) > 0 { - err = validateGroups(routeToSave.PeerGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil { + return nil, err } } if len(routeToSave.AccessControlGroups) > 0 { - err = validateGroups(routeToSave.AccessControlGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil { + return nil, err } } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) - if err != nil { - return err + if err = validateGroups(routeToSave.Groups, groupsMap); err != nil { + return nil, err } - err = validateGroups(routeToSave.Groups, account.Groups) - if err != nil { - return err - } - - oldRoute := account.Routes[routeToSave.ID] - account.Routes[routeToSave.ID] = routeToSave - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.UpdateAccountPeers(ctx, accountID) - } - - am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) - - return nil -} - -// DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - routy := account.Routes[routeID] - if routy == nil { - return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) - } - delete(account.Routes, routeID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - - if am.isRouteChangeAffectPeers(account, routy) { - am.UpdateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + return groupsMap, nil } func toProtocolRoute(route *route.Route) *proto.Route { @@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { return &portInfo } -// isRouteChangeAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { - return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +// areRouteChangesAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers. +func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { + if route.Peer != "" { + return true, nil + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups) +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { + accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + routes := make([]*route.Route, 0) + for _, r := range accountRoutes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes, nil } diff --git a/management/server/status/error.go b/management/server/status/error.go index 8fbe0bad9..5a6f6d1a7 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error { func NewOperationNotFoundError(operation operations.Operation) error { return Errorf(NotFound, "operation: %s not found", operation) } + +func NewRouteNotFoundError(routeID string) error { + return Errorf(NotFound, "route: %s not found", routeID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d81890775..a6c4d56bf 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -23,8 +23,6 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/logger" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -34,6 +32,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db, lockStrength, accountID) + var routes []*route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&routes, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get routes from store") + } + + return routes, nil } // GetRouteByID retrieves a route by its ID and account ID. -func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + var route *route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&route, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewRouteNotFoundError(routeID) + } + log.WithContext(ctx).Errorf("failed to get route from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get route from store") + } + + return route, nil +} + +// SaveRoute saves a route to the database. +func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) + return status.Errorf(status.Internal, "failed to save route to store") + } + + return nil +} + +// DeleteRoute deletes a route from the database. +func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete route from store") + } + + if result.RowsAffected == 0 { + return status.NewRouteNotFoundError(routeID) + } + + return nil } // GetAccountSetupKeys retrieves setup keys for an account. @@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki return nil } -// getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record []T - - result := tx.Find(&record, accountIDCondition, accountID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) - } - - return record, nil -} - -// getRecordByID retrieves a record by its ID and account ID from the database. -func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record T - - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "%s not found", recordType) - } - return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) - } - return &record, nil -} - // SaveDNSSettings saves the DNS settings to the store. func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2c1f5f8e6..fab9048e5 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -19,21 +19,17 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - - route2 "github.com/netbirdio/netbird/route" - - "github.com/netbirdio/netbird/management/server/status" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" nbroute "github.com/netbirdio/netbird/route" + route2 "github.com/netbirdio/netbird/route" ) func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { @@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } +func TestSqlStore_GetAccountRoutes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve routes by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, routes, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetRouteByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + routeID string + expectError bool + }{ + { + name: "retrieve existing route", + routeID: "ct03t427qv97vmtmglog", + expectError: false, + }, + { + name: "retrieve non-existing route", + routeID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty route ID", + routeID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, route) + } else { + require.NoError(t, err) + require.NotNil(t, route) + require.Equal(t, tt.routeID, string(route.ID)) + } + }) + } +} + +func TestSqlStore_SaveRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + route := &route2.Route{ + ID: "route-id", + AccountID: accountID, + Network: netip.MustParsePrefix("10.10.0.0/16"), + NetID: "netID", + PeerGroups: []string{"routeA"}, + NetworkType: route2.IPv4Network, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + AccessControlGroups: []string{}, + } + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + require.NoError(t, err) + + saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID)) + require.NoError(t, err) + require.Equal(t, route, saveRoute) + +} + +func TestSqlStore_DeleteRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + routeID := "ct03t427qv97vmtmglog" + + err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID) + require.NoError(t, err) + + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID) + require.Error(t, err) + require.Nil(t, route) +} func TestSqlStore_GetAccountMeta(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index c7b103454..d41379b1c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -145,7 +145,9 @@ type Store interface { DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) - GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) + SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error + DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 324bf6293..0393d1ade 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); +INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index da230f0b2..090ba76e4 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -36,6 +36,9 @@ const ( PublicCategory = "public" PrivateCategory = "private" UnknownCategory = "unknown" + + // firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules. + firewallRuleMinPortRangesVer = "0.48.0" ) type LookupMap map[string]struct{} @@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map // GetPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) +func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer) + for _, policy := range a.Policies { if !policy.Enabled { continue @@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // It safe to call the generator function multiple times for same peer and different rules no duplicates will be // generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { +func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { rulesExists := make(map[string]struct{}) peersExists := make(map[string]struct{}) rules := make([]*FirewallRule, 0) @@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, continue } - for _, port := range rule.Ports { - pr := fr // clone rule and add set new port - pr.Port = port - rules = append(rules, &pr) - } - - for _, portRange := range rule.PortRanges { - pr := fr - pr.PortRange = portRange - rules = append(rules, &pr) - } + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error { } return nil } + +// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules +func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { + var expanded []*FirewallRule + + if len(rule.Ports) > 0 { + for _, port := range rule.Ports { + fr := base + fr.Port = port + expanded = append(expanded, &fr) + } + return expanded + } + + supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion) + for _, portRange := range rule.PortRanges { + fr := base + + if supportPortRanges { + fr.PortRange = portRange + } else { + // Peer doesn't support port ranges, only allow single-port ranges + if portRange.Start != portRange.End { + continue + } + fr.Port = strconv.FormatUint(uint64(portRange.Start), 10) + } + expanded = append(expanded, &fr) + } + + return expanded +} + +// peerSupportsPortRanges checks if the peer version supports port ranges. +func peerSupportsPortRanges(peerVer string) bool { + if strings.Contains(peerVer, "dev") { + return true + } + + meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) + return err == nil && meetMinVer +} diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index ef54abea2..19222a607 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) } else { rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) - } // TODO: generate IPv6 rules for dynamic routes