Merge branch 'main' into feat/multiple-profile

This commit is contained in:
Hakan Sariman 2025-06-22 08:44:52 +03:00
commit 69fa927360
34 changed files with 1485 additions and 444 deletions

View File

@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences export a subset of the internal config for gomobile // Preferences exports a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences create new Preferences instance // NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL read url from config file // GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL store the given url and wait for commit // SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL read url from config file // GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL store the given url and wait for commit // SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey read preshared key from config file // GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey store the given key and wait for commit // SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled store if rosenpass is enabled // SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) { func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled p.configInput.RosenpassEnabled = &enabled
} }
// GetRosenpassEnabled read rosenpass enabled from config file // GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) { func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil { if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil return *p.configInput.RosenpassEnabled, nil
@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err return cfg.RosenpassEnabled, err
} }
// SetRosenpassPermissive store the given permissive and wait for commit // SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) { func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive p.configInput.RosenpassPermissive = &permissive
} }
// GetRosenpassPermissive read rosenpass permissive from config file // GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) { func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil { if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil return *p.configInput.RosenpassPermissive, nil
@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err return cfg.RosenpassPermissive, err
} }
// Commit write out the changes into config file // GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
} }

View File

@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance

View File

@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
if drop { r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = &protoMatch{
ips: map[string]int{}, ips: map[string]int{},

View File

@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
} }
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{ PeerConfig: &mgmProto.PeerConfig{

View File

@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility if runtime.GOOS == "android" {
log.Infof("falling back to enabled SSH server for pre-existing configuration") // default to disabled SSH on Android for security
config.ServerSSHAllowed = util.True() log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }

View File

@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(), PeerConnID: conn.ConnID(),
Log: conn.Log, Log: conn.Log,
} }
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil { if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {

View File

@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
} }
switch runtime.GOOS { switch runtime.GOOS {

View File

@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
func (i *Monitor) ResetTimer() { func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold) i.timer.Reset(i.inactivityThreshold)
} }
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@ -58,7 +58,7 @@ type Manager struct {
// Route HA group management // Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings routesMu sync.RWMutex
onInactive chan peerid.ConnID onInactive chan peerid.ConnID
} }
@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
case peerConnID := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID) m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive: case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID) m.onPeerInactivityTimedOut(ctx, peerConnID)
} }
} }
} }
@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added return added
} }
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
peerCfg: &peerCfg, peerCfg: &peerCfg,
expectedWatcher: watcherActivity, expectedWatcher: watcherActivity,
} }
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil return false, nil
} }
@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock() m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID] haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 { if len(haGroups) == 0 {
m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID) log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return return
} }
activatedCount := 0
for _, haGroup := range haGroups { for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup] peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers { for _, peerID := range peers {
if peerID == triggerPeerID { if peerID != triggerPeerID {
continue peersToActivate = append(peersToActivate, peerID)
} }
}
}
m.routesMu.RUnlock()
cfg, mp := m.getPeerForActivation(peerID) activatedCount := 0
if cfg == nil { for _, peerID := range peersToActivate {
continue cfg, mp := m.getPeerForActivation(peerID)
} if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) { if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++ activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID) cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
} }
} }
@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
} }
} }
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed") peerCfg.Log.Warnf("peer already managed")
@ -415,6 +469,48 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed") log.Infof("lazy connection manager closed")
} }
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
return return
} }
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp.peerCfg.Log.Infof("connection timed out") mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized // this is blocking operation, potentially can be optimized
@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok { if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer") mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return return
} }

View File

@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// IsConnected unit tests only // IsConnected returns true if the peer is connected
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool { func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.currentConnPriority != conntype.None
return conn.evalStatus() == StatusConnected
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {

View File

@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap GetClientRoutesFunc func() route.HAMap

View File

@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() { if r.IsDynamic() {
continue continue
} }
@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" { if runtime.GOOS != "android" {
return return
} }
newNets := make([]string, 0)
var newNets []string
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String()) newNets = append(newNets, r.Network.String())
} }
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.initialRouteRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, prefix := range prefixes { for _, prefix := range prefixes {
@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.routeRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }

View File

@ -59,16 +59,16 @@ type Info struct {
Environment Environment Environment Environment
Files []File // for posture checks Files []File // for posture checks
RosenpassEnabled bool RosenpassEnabled bool
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }

View File

@ -1853,40 +1853,49 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
} }
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId) var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
account, err = transaction.GetAccount(ctx, accountId)
if err != nil {
return err
}
if account.IsDomainPrimaryAccount {
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return status.Errorf(status.Internal, "failed to update account to primary")
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.IsDomainPrimaryAccount {
return account, nil
}
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}
return account, nil return account, nil
} }

View File

@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)

View File

@ -426,6 +426,10 @@ components:
items: items:
type: string type: string
example: "stage-host-1" example: "stage-host-1"
ephemeral:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
required: required:
- city_name - city_name
- connected - connected
@ -450,6 +454,7 @@ components:
- approval_required - approval_required
- serial_number - serial_number
- extra_dns_labels - extra_dns_labels
- ephemeral
AccessiblePeer: AccessiblePeer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'

View File

@ -1016,6 +1016,9 @@ type Peer struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`
@ -1097,6 +1100,9 @@ type PeerBatch struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`

View File

@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
} }
} }

View File

@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID) return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err != nil { a, err := transaction.GetAccountByUser(ctx, userID)
return err if err != nil {
} return err
}
var extra *types.ExtraSettings var extra *types.ExtraSettings
if a.Settings.Extra != nil { if a.Settings.Extra != nil {
extra = a.Settings.Extra extra = a.Settings.Extra
} else { } else {
extra = &types.ExtraSettings{} extra = &types.ExtraSettings{}
a.Settings.Extra = extra a.Settings.Extra = extra
} }
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(ctx, a) return transaction.SaveAccount(ctx, a)
})
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {

View File

@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@ -1149,7 +1149,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID { if aclPeer.ID == peer.ID {
return peer, nil return peer, nil

View File

@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
ID: "peerB", ID: "peerB",
IP: net.ParseIP("100.65.80.39"), IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
}, },
"peerC": { "peerC": {
ID: "peerC", ID: "peerC",
@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
IP: net.ParseIP("100.65.31.2"), IP: net.ParseIP("100.65.31.2"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
"peerK": {
ID: "peerK",
IP: net.ParseIP("100.32.80.1"),
Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
},
}, },
Groups: map[string]*types.Group{ Groups: map[string]*types.Group{
"GroupAll": { "GroupAll": {
@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerI", "peerI",
}, },
}, },
"GroupWorkflow": {
ID: "GroupWorkflow",
Name: "workflow",
Peers: []string{
"peerK",
},
},
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
{ {
@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
}, },
}, },
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 8088,
End: 8088,
},
{
Start: 9090,
End: 9095,
},
},
Sources: []string{
"GroupWorkflow",
},
Destinations: []string{
"GroupDMZ",
},
},
},
},
}, },
} }
@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
assert.Len(t, peers, 8) assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.True(t, contains, "rule not found in expected rules %#v", rule) assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
}
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
})
} }
func TestAccount_getPeersByPolicyDirect(t *testing.T) { func TestAccount_getPeersByPolicyDirect(t *testing.T) {
@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
} }
@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
} }
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerVersion := sanitizeVersion(peer.Meta.WtVersion) meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
minVersion := sanitizeVersion(n.MinVersion)
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil { if err != nil {
return false, err return false, err
} }
constraints, err := version.NewConstraint(">= " + minVersion) if meetsMin {
if err != nil {
return false, err
}
if constraints.Check(peerNBVersion) {
return true, nil return true, nil
} }
@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
} }
return nil return nil
} }
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
peerVer = sanitizeVersion(peerVer)
minVer = sanitizeVersion(minVer)
peerNBVer, err := version.NewVersion(peerVer)
if err != nil {
return false, err
}
constraints, err := version.NewConstraint(">= " + minVer)
if err != nil {
return false, err
}
return constraints.Check(peerNBVer), nil
}

View File

@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
}) })
} }
} }
func TestMeetsMinVersion(t *testing.T) {
tests := []struct {
name string
minVer string
peerVer string
want bool
wantErr bool
}{
{
name: "Peer version greater than min version",
minVer: "0.26.0",
peerVer: "0.60.1",
want: true,
wantErr: false,
},
{
name: "Peer version equals min version",
minVer: "1.0.0",
peerVer: "1.0.0",
want: true,
wantErr: false,
},
{
name: "Peer version less than min version",
minVer: "1.0.0",
peerVer: "0.9.9",
want: false,
wantErr: false,
},
{
name: "Peer version with pre-release tag greater than min version",
minVer: "1.0.0",
peerVer: "1.0.1-alpha",
want: true,
wantErr: false,
},
{
name: "Invalid peer version format",
minVer: "1.0.0",
peerVer: "dev",
want: false,
wantErr: true,
},
{
name: "Invalid min version format",
minVer: "invalid.version",
peerVer: "1.0.0",
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

View File

@ -4,19 +4,19 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"unicode/utf8" "unicode/utf8"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route // we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped // when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID { if checkRoute.ID == prefixRoute.ID {
continue continue
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, ok := peerGroupsMap[groupID]
if group == nil { if !ok || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
} }
if peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range checkRoute.PeerGroups {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, ok := peersMap[id]
if peer == nil { if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
if len(domains) == 0 && !prefix.IsValid() { var newRoute *route.Route
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") var updateAccountPeers bool
}
if len(domains) > 0 { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
prefix = getPlaceholderIP() newRoute = &route.Route{
} ID: route.ID(xid.New().String()),
AccountID: accountID,
if peerID != "" && len(peerGroupIDs) != 0 { Network: prefix,
return nil, status.Errorf( Domains: domains,
status.InvalidArgument, KeepRoute: keepRoute,
"peer with ID %s and peers group %s should not be provided at the same time", NetID: netID,
peerID, peerGroupIDs) Description: description,
} Peer: peerID,
PeerGroups: peerGroupIDs,
var newRoute route.Route NetworkType: networkType,
newRoute.ID = route.ID(xid.New().String()) Masquerade: masquerade,
Metric: metric,
if len(peerGroupIDs) > 0 { Enabled: enabled,
err = validateGroups(peerGroupIDs, account.Groups) Groups: groups,
if err != nil { AccessControlGroups: accessControlGroupIDs,
return nil, err
} }
}
if len(accessControlGroupIDs) > 0 { if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
err = validateGroups(accessControlGroupIDs, account.Groups) return err
if err != nil {
return nil, err
} }
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
if err != nil {
return nil, err
}
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) return groupsMap, nil
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if am.isRouteChangeAffectPeers(account, routy) {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }

View File

@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error {
func NewOperationNotFoundError(operation operations.Operation) error { func NewOperationNotFoundError(operation operations.Operation) error {
return Errorf(NotFound, "operation: %s not found", operation) return Errorf(NotFound, "operation: %s not found", operation)
} }
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@ -23,8 +23,6 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@ -34,6 +32,7 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db, lockStrength, accountID) var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.NewRouteNotFoundError(routeID)
}
return nil
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil return nil
} }
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record []T
result := tx.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record T
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store. // SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).

View File

@ -19,21 +19,17 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
) )
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups)) require.Equal(t, 8003, len(accountGroups))
} }
func TestSqlStore_GetAccountRoutes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve routes by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, routes, tt.expectedCount)
})
}
}
func TestSqlStore_GetRouteByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
routeID string
expectError bool
}{
{
name: "retrieve existing route",
routeID: "ct03t427qv97vmtmglog",
expectError: false,
},
{
name: "retrieve non-existing route",
routeID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty route ID",
routeID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, route)
} else {
require.NoError(t, err)
require.NotNil(t, route)
require.Equal(t, tt.routeID, string(route.ID))
}
})
}
}
func TestSqlStore_SaveRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
AccessControlGroups: []string{},
}
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
require.NoError(t, err)
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
require.NoError(t, err)
require.Equal(t, route, saveRoute)
}
func TestSqlStore_DeleteRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
routeID := "ct03t427qv97vmtmglog"
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
require.NoError(t, err)
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
require.Error(t, err)
require.Nil(t, route)
}
func TestSqlStore_GetAccountMeta(t *testing.T) { func TestSqlStore_GetAccountMeta(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@ -145,7 +145,9 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)

View File

@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@ -36,6 +36,9 @@ const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
) )
type LookupMap map[string]struct{} type LookupMap map[string]struct{}
@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
continue continue
@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
continue continue
} }
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
continue continue
} }
for _, port := range rule.Ports { rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
for _, portRange := range rule.PortRanges {
pr := fr
pr.PortRange = portRange
rules = append(rules, &pr)
}
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules return peers, rules
@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error {
} }
return nil return nil
} }
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
var expanded []*FirewallRule
if len(rule.Ports) > 0 {
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
return expanded
}
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
for _, portRange := range rule.PortRanges {
fr := base
if supportPortRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
return expanded
}
// peerSupportsPortRanges checks if the peer version supports port ranges.
func peerSupportsPortRanges(peerVer string) bool {
if strings.Contains(peerVer, "dev") {
return true
}
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}

View File

@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else { } else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
} }
// TODO: generate IPv6 rules for dynamic routes // TODO: generate IPv6 rules for dynamic routes