diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 693fbfe01..ca5eafa62 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -149,6 +149,7 @@ nfpms: dockers: - image_templates: - netbirdio/netbird:{{ .Version }}-amd64 + - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64 ids: - netbird goarch: amd64 @@ -164,6 +165,7 @@ dockers: - "--label=maintainer=dev@netbird.io" - image_templates: - netbirdio/netbird:{{ .Version }}-arm64v8 + - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8 ids: - netbird goarch: arm64 diff --git a/client/android/client.go b/client/android/client.go index 229bcd974..3b8a5bd0f 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -59,6 +59,8 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener + + connectClient *internal.ConnectClient } // NewClient instantiate a new Client @@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) - return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) - return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // Stop the internal client and free the resources @@ -174,6 +176,53 @@ func (c *Client) PeersList() *PeerInfoArray { return &PeerInfoArray{items: peerInfos} } +func (c *Client) Networks() *NetworkArray { + if c.connectClient == nil { + log.Error("not connected") + return nil + } + + engine := c.connectClient.Engine() + if engine == nil { + log.Error("could not get engine") + return nil + } + + routeManager := engine.GetRouteManager() + if routeManager == nil { + log.Error("could not get route manager") + return nil + } + + networkArray := &NetworkArray{ + items: make([]Network, 0), + } + + for id, routes := range routeManager.GetClientRoutesWithNetID() { + if len(routes) == 0 { + continue + } + + if routes[0].IsDynamic() { + continue + } + + peer, err := c.recorder.GetPeer(routes[0].Peer) + if err != nil { + log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) + continue + } + network := Network{ + Name: string(id), + Network: routes[0].Network.String(), + Peer: peer.FQDN, + Status: peer.ConnStatus.String(), + } + networkArray.Add(network) + } + return networkArray +} + // OnUpdatedHostDNS update the DNS servers addresses for root zones func (c *Client) OnUpdatedHostDNS(list *DNSList) error { dnsServer, err := dns.GetServerDns() diff --git a/client/android/networks.go b/client/android/networks.go new file mode 100644 index 000000000..aa130420b --- /dev/null +++ b/client/android/networks.go @@ -0,0 +1,27 @@ +//go:build android + +package android + +type Network struct { + Name string + Network string + Peer string + Status string +} + +type NetworkArray struct { + items []Network +} + +func (array *NetworkArray) Add(s Network) *NetworkArray { + array.items = append(array.items, s) + return array +} + +func (array *NetworkArray) Get(i int) *Network { + return &array.items[i] +} + +func (array *NetworkArray) Size() int { + return len(array.items) +} diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 9f6fcddd6..1f5564c72 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -7,30 +7,23 @@ type PeerInfo struct { ConnStatus string // Todo replace to enum } -// PeerInfoCollection made for Java layer to get non default types as collection -type PeerInfoCollection interface { - Add(s string) PeerInfoCollection - Get(i int) string - Size() int -} - -// PeerInfoArray is the implementation of the PeerInfoCollection +// PeerInfoArray is a wrapper of []PeerInfo type PeerInfoArray struct { items []PeerInfo } // Add new PeerInfo to the collection -func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray { +func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray { array.items = append(array.items, s) return array } // Get return an element of the collection -func (array PeerInfoArray) Get(i int) *PeerInfo { +func (array *PeerInfoArray) Get(i int) *PeerInfo { return &array.items[i] } // Size return with the size of the collection -func (array PeerInfoArray) Size() int { +func (array *PeerInfoArray) Size() int { return len(array.items) } diff --git a/client/android/preferences.go b/client/android/preferences.go index 2a8b197e7..2d5668d1c 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -4,12 +4,12 @@ import ( "github.com/netbirdio/netbird/client/internal" ) -// Preferences export a subset of the internal config for gomobile +// Preferences exports a subset of the internal config for gomobile type Preferences struct { configInput internal.ConfigInput } -// NewPreferences create new Preferences instance +// NewPreferences creates a new Preferences instance func NewPreferences(configPath string) *Preferences { ci := internal.ConfigInput{ ConfigPath: configPath, @@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences { return &Preferences{ci} } -// GetManagementURL read url from config file +// GetManagementURL reads URL from config file func (p *Preferences) GetManagementURL() (string, error) { if p.configInput.ManagementURL != "" { return p.configInput.ManagementURL, nil @@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) { return cfg.ManagementURL.String(), err } -// SetManagementURL store the given url and wait for commit +// SetManagementURL stores the given URL and waits for commit func (p *Preferences) SetManagementURL(url string) { p.configInput.ManagementURL = url } -// GetAdminURL read url from config file +// GetAdminURL reads URL from config file func (p *Preferences) GetAdminURL() (string, error) { if p.configInput.AdminURL != "" { return p.configInput.AdminURL, nil @@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) { return cfg.AdminURL.String(), err } -// SetAdminURL store the given url and wait for commit +// SetAdminURL stores the given URL and waits for commit func (p *Preferences) SetAdminURL(url string) { p.configInput.AdminURL = url } -// GetPreSharedKey read preshared key from config file +// GetPreSharedKey reads pre-shared key from config file func (p *Preferences) GetPreSharedKey() (string, error) { if p.configInput.PreSharedKey != nil { return *p.configInput.PreSharedKey, nil @@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) { return cfg.PreSharedKey, err } -// SetPreSharedKey store the given key and wait for commit +// SetPreSharedKey stores the given key and waits for commit func (p *Preferences) SetPreSharedKey(key string) { p.configInput.PreSharedKey = &key } -// SetRosenpassEnabled store if rosenpass is enabled +// SetRosenpassEnabled stores whether Rosenpass is enabled func (p *Preferences) SetRosenpassEnabled(enabled bool) { p.configInput.RosenpassEnabled = &enabled } -// GetRosenpassEnabled read rosenpass enabled from config file +// GetRosenpassEnabled reads Rosenpass enabled status from config file func (p *Preferences) GetRosenpassEnabled() (bool, error) { if p.configInput.RosenpassEnabled != nil { return *p.configInput.RosenpassEnabled, nil @@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) { return cfg.RosenpassEnabled, err } -// SetRosenpassPermissive store the given permissive and wait for commit +// SetRosenpassPermissive stores the given permissive setting and waits for commit func (p *Preferences) SetRosenpassPermissive(permissive bool) { p.configInput.RosenpassPermissive = &permissive } -// GetRosenpassPermissive read rosenpass permissive from config file +// GetRosenpassPermissive reads Rosenpass permissive setting from config file func (p *Preferences) GetRosenpassPermissive() (bool, error) { if p.configInput.RosenpassPermissive != nil { return *p.configInput.RosenpassPermissive, nil @@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return cfg.RosenpassPermissive, err } -// Commit write out the changes into config file +// GetDisableClientRoutes reads disable client routes setting from config file +func (p *Preferences) GetDisableClientRoutes() (bool, error) { + if p.configInput.DisableClientRoutes != nil { + return *p.configInput.DisableClientRoutes, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableClientRoutes, err +} + +// SetDisableClientRoutes stores the given value and waits for commit +func (p *Preferences) SetDisableClientRoutes(disable bool) { + p.configInput.DisableClientRoutes = &disable +} + +// GetDisableServerRoutes reads disable server routes setting from config file +func (p *Preferences) GetDisableServerRoutes() (bool, error) { + if p.configInput.DisableServerRoutes != nil { + return *p.configInput.DisableServerRoutes, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableServerRoutes, err +} + +// SetDisableServerRoutes stores the given value and waits for commit +func (p *Preferences) SetDisableServerRoutes(disable bool) { + p.configInput.DisableServerRoutes = &disable +} + +// GetDisableDNS reads disable DNS setting from config file +func (p *Preferences) GetDisableDNS() (bool, error) { + if p.configInput.DisableDNS != nil { + return *p.configInput.DisableDNS, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableDNS, err +} + +// SetDisableDNS stores the given value and waits for commit +func (p *Preferences) SetDisableDNS(disable bool) { + p.configInput.DisableDNS = &disable +} + +// GetDisableFirewall reads disable firewall setting from config file +func (p *Preferences) GetDisableFirewall() (bool, error) { + if p.configInput.DisableFirewall != nil { + return *p.configInput.DisableFirewall, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableFirewall, err +} + +// SetDisableFirewall stores the given value and waits for commit +func (p *Preferences) SetDisableFirewall(disable bool) { + p.configInput.DisableFirewall = &disable +} + +// GetServerSSHAllowed reads server SSH allowed setting from config file +func (p *Preferences) GetServerSSHAllowed() (bool, error) { + if p.configInput.ServerSSHAllowed != nil { + return *p.configInput.ServerSSHAllowed, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.ServerSSHAllowed == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.ServerSSHAllowed, err +} + +// SetServerSSHAllowed stores the given value and waits for commit +func (p *Preferences) SetServerSSHAllowed(allowed bool) { + p.configInput.ServerSSHAllowed = &allowed +} + +// GetBlockInbound reads block inbound setting from config file +func (p *Preferences) GetBlockInbound() (bool, error) { + if p.configInput.BlockInbound != nil { + return *p.configInput.BlockInbound, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.BlockInbound, err +} + +// SetBlockInbound stores the given value and waits for commit +func (p *Preferences) SetBlockInbound(block bool) { + p.configInput.BlockInbound = &block +} + +// Commit writes out the changes to the config file func (p *Preferences) Commit() error { _, err := internal.UpdateOrCreateConfig(p.configInput) return err diff --git a/client/cmd/status.go b/client/cmd/status.go index e466f73ab..a85ee925e 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error { return err } - if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { + status := resp.GetStatus() + + if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || + status == string(internal.StatusSessionExpired) { cmd.Printf("Daemon status: %s\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+ " netbird up \n\n"+ diff --git a/client/cmd/system.go b/client/cmd/system.go index 83ce8d215..f63432401 100644 --- a/client/cmd/system.go +++ b/client/cmd/system.go @@ -38,5 +38,5 @@ func init() { upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ - "This overrides any policies received from the management service.") + "This overrides any policies received from the management service.") } diff --git a/client/cmd/trace.go b/client/cmd/trace.go index abb73b646..655838260 100644 --- a/client/cmd/trace.go +++ b/client/cmd/trace.go @@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error { } func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { - cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) + cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) for _, stage := range resp.Stages { if stage.ForwardingDetails != nil { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 3de0bb3f4..bcf6d894b 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -62,5 +62,5 @@ type ConnKey struct { } func (c ConnKey) String() string { - return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) } diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index c8ea159da..509c1549b 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -3,6 +3,7 @@ package conntrack import ( "context" "fmt" + "net" "net/netip" "sync" "time" @@ -19,6 +20,10 @@ const ( DefaultICMPTimeout = 30 * time.Second // ICMPCleanupInterval is how often we check for stale ICMP connections ICMPCleanupInterval = 15 * time.Second + + // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, + // which includes the IP header (20 bytes) and transport header (8 bytes) + MaxICMPPayloadLength = 28 ) // ICMPConnKey uniquely identifies an ICMP connection @@ -29,7 +34,7 @@ type ICMPConnKey struct { } func (i ICMPConnKey) String() string { - return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID) + return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID) } // ICMPConnTrack represents an ICMP connection state @@ -50,6 +55,72 @@ type ICMPTracker struct { flowLogger nftypes.FlowLogger } +// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs +type ICMPInfo struct { + TypeCode layers.ICMPv4TypeCode + PayloadData [MaxICMPPayloadLength]byte + // actual length of valid data + PayloadLen int +} + +// String implements fmt.Stringer for lazy evaluation in log messages +func (info ICMPInfo) String() string { + if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength { + if origInfo := info.parseOriginalPacket(); origInfo != "" { + return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo) + } + } + + return info.TypeCode.String() +} + +// isErrorMessage returns true if this ICMP type carries original packet info +func (info ICMPInfo) isErrorMessage() bool { + typ := info.TypeCode.Type() + return typ == 3 || // Destination Unreachable + typ == 5 || // Redirect + typ == 11 || // Time Exceeded + typ == 12 // Parameter Problem +} + +// parseOriginalPacket extracts info about the original packet from ICMP payload +func (info ICMPInfo) parseOriginalPacket() string { + if info.PayloadLen < MaxICMPPayloadLength { + return "" + } + + // TODO: handle IPv6 + if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 { + return "" + } + + protocol := info.PayloadData[9] + srcIP := net.IP(info.PayloadData[12:16]) + dstIP := net.IP(info.PayloadData[16:20]) + + transportData := info.PayloadData[20:] + + switch nftypes.Protocol(protocol) { + case nftypes.TCP: + srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) + dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) + return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + + case nftypes.UDP: + srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) + dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) + return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + + case nftypes.ICMP: + icmpType := transportData[0] + icmpCode := transportData[1] + return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode) + + default: + return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP) + } +} + // NewICMPTracker creates a new ICMP connection tracker func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { if timeout == 0 { @@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint } // TrackOutbound records an outbound ICMP connection -func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { +func (t *ICMPTracker) TrackOutbound( + srcIP netip.Addr, + dstIP netip.Addr, + id uint16, + typecode layers.ICMPv4TypeCode, + payload []byte, + size int, +) { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size) + t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size) } } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) { - t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size) +func (t *ICMPTracker) TrackInbound( + srcIP netip.Addr, + dstIP netip.Addr, + id uint16, + typecode layers.ICMPv4TypeCode, + ruleId []byte, + payload []byte, + size int, +) { + t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) { +func (t *ICMPTracker) track( + srcIP netip.Addr, + dstIP netip.Addr, + id uint16, + typecode layers.ICMPv4TypeCode, + direction nftypes.Direction, + ruleId []byte, + payload []byte, + size int, +) { key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) if exists { return } typ, code := typecode.Type(), typecode.Code() + icmpInfo := ICMPInfo{ + TypeCode: typecode, + } + if len(payload) > 0 { + icmpInfo.PayloadLen = len(payload) + if icmpInfo.PayloadLen > MaxICMPPayloadLength { + icmpInfo.PayloadLen = MaxICMPPayloadLength + } + copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen]) + } // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { - t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) + t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) + t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.sendEvent(nftypes.TypeStart, conn, ruleId) } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 5a7b36a36..b15b42cf0 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0) } }) @@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0) } b.ResetTimer() diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 3720eedfa..e18c083b9 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -86,5 +86,5 @@ type epID stack.TransportEndpointID func (i epID) String() string { // src and remote is swapped - return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) + return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 64e54e293..aa42f811b 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn if errInToOut != nil { if !isClosedError(errInToOut) { - f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut) + f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) } } if errOutToIn != nil { if !isClosedError(errOutToIn) { - f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn) + f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) } } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index f237a313d..3a761d06b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack wg.Wait() if outboundErr != nil && !isClosedError(outboundErr) { - f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr) + f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr) } if inboundErr != nil && !isClosedError(inboundErr) { - f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr) + f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr) } var rxPackets, txPackets uint64 diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index c216bc302..dcff92c61 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -671,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) + m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } } @@ -684,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size) + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index ab3e611e1..ae9e29bd1 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -24,6 +24,7 @@ type WGTunDevice struct { mtu int iceBind *bind.ICEBind tunAdapter TunAdapter + disableDNS bool name string device *device.Device @@ -32,7 +33,7 @@ type WGTunDevice struct { configurer WGConfigurer } -func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { +func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ address: address, port: port, @@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind mtu: mtu, iceBind: iceBind, tunAdapter: tunAdapter, + disableDNS: disableDNS, } } @@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string routesString := routesToString(routes) searchDomainsToString := searchDomainsToString(searchDomains) + // Skip DNS configuration when DisableDNS is enabled + if t.disableDNS { + log.Info("DNS is disabled, skipping DNS and search domain configuration") + dns = "" + searchDomainsToString = "" + } + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) diff --git a/client/iface/iface.go b/client/iface/iface.go index 7d609f4cd..006dfe4e7 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -43,6 +43,7 @@ type WGIFaceOpts struct { MobileArgs *device.MobileIFaceArguments TransportNet transport.Net FilterFn bind.FilterFn + DisableDNS bool } // WGIface represents an interface instance diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 35046b887..c8babea32 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), + tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), wgProxyFactory: wgproxy.NewUSPFactory(iceBind), } return wgIFace, nil diff --git a/client/internal/config.go b/client/internal/config.go index 45a7620e1..37ee1e1bf 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + // default to disabling server routes on Android for security + DisableServerRoutes: runtime.GOOS == "android", } if _, err := config.apply(input); err != nil { @@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { config.ServerSSHAllowed = input.ServerSSHAllowed updated = true } else if config.ServerSSHAllowed == nil { - // enables SSH for configs from old versions to preserve backwards compatibility - log.Infof("falling back to enabled SSH server for pre-existing configuration") - config.ServerSSHAllowed = util.True() + if runtime.GOOS == "android" { + // default to disabled SSH on Android for security + log.Infof("setting SSH server to false by default on Android") + config.ServerSSHAllowed = util.False() + } else { + // enables SSH for configs from old versions to preserve backwards compatibility + log.Infof("falling back to enabled SSH server for pre-existing configuration") + config.ServerSSHAllowed = util.True() + } updated = true } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 758b14e46..36da8fb78 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,9 +11,10 @@ import ( ) const ( - PriorityDNSRoute = 100 - PriorityMatchDomain = 50 - PriorityDefault = 1 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 PriorityFallback = -100 ) diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 5f03e0758..72c0004d5 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { // Setup handlers with different priorities chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) - chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) // Create test request @@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "test.example.com.", @@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "test.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "sub.test.example.com.", @@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { // Add handlers in priority order chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) - chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) // Create test request @@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: true, }, }, { @@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: true, - nbdns.PriorityMatchDomain: false, + nbdns.PriorityDNSRoute: true, + nbdns.PriorityUpstream: false, }, }, { @@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityDefault}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: false, - nbdns.PriorityDefault: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: false, + nbdns.PriorityDefault: true, }, }, } @@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Add handlers in mixed order chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) - chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream) // Test 1: Initial state w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} @@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { defaultHandler.Calls = nil // Test 3: Remove middle priority handler - chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + chain.RemoveHandler(testDomain, nbdns.PriorityUpstream) w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called @@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { shouldMatch bool }{ {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, - {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"example.com.", nbdns.PriorityUpstream, false, false}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, }, query: "example.com.", @@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, - {"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, + {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false}, + {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "test.sub.example.com.", expectedMatch: "sub.example.com.", @@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "example.com.", nbdns.PriorityDNSRoute, true}, }, query: "sub.example.com.", @@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "other.example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 33977b4c8..f933c1de0 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -567,7 +567,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) muxUpdates = append(muxUpdates, handlerWrapper{ domain: customZone.Domain, handler: s.localResolver, - priority: PriorityMatchDomain, + priority: PriorityLocal, }) for _, record := range customZone.Records { @@ -606,7 +606,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityMatchDomain + basePriority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { basePriority = PriorityDefault } @@ -683,9 +683,9 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai } func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { - if basePriority == PriorityMatchDomain && priority <= PriorityDefault { + if basePriority == PriorityUpstream && priority <= PriorityDefault { log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityMatchDomain-PriorityDefault) + domainGroup.domain, PriorityUpstream-PriorityDefault) return true } if basePriority == PriorityDefault && priority <= PriorityFallback { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 73f95ca4e..0fd245a59 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, @@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "local-resolver": handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, }, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, @@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { "id1": handlerWrapper{ domain: zoneRecords[0].Name, handler: &local.Resolver{}, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} @@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { } chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) - chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) + chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream) testCases := []struct { name string @@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, } @@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, "upstream-other": { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } @@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain + 1, + priority: PriorityUpstream + 1, }, // Keep existing groups with their original priorities { @@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, // Add group3 with lowest priority { @@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain - 2, + priority: PriorityUpstream - 2, }, }, expectedHandlers: map[string]string{ @@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "new.com", handler: &mockHandler{ Id: "upstream-new", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) { // Register domains from different handlers with same domain server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) - server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream) // Verify refcount is 2 zoneKey := toZone("shared.example.com") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") // Deregister one handler - server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) + server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream) // Verify refcount is 1 assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") @@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) { } server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) - server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream) assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") @@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) { assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") } + +func TestLocalResolverPriorityInServer(t *testing.T) { + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + handlerChain: NewHandlerChain(), + localResolver: local.NewResolver(), + service: &mockService{}, + extraDomains: make(map[domain.Domain]int), + } + + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + Domains: []string{"local.example.com"}, // Same domain as local records + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + + upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) + assert.NoError(t, err) + + // Verify that local handler has higher priority than upstream for same domain + var localPriority, upstreamPriority int + localFound, upstreamFound := false, false + + for _, update := range localMuxUpdates { + if update.domain == "local.example.com" { + localPriority = update.priority + localFound = true + } + } + + for _, update := range upstreamMuxUpdates { + if update.domain == "local.example.com" { + upstreamPriority = update.priority + upstreamFound = true + } + } + + assert.True(t, localFound, "Local handler should be found") + assert.True(t, upstreamFound, "Upstream handler should be found") + assert.Greater(t, localPriority, upstreamPriority, + "Local handler priority (%d) should be higher than upstream priority (%d)", + localPriority, upstreamPriority) + assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal") + assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream") +} + +func TestLocalResolverPriorityConstants(t *testing.T) { + // Test that priority constants are ordered correctly + assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route") + assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream") + assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default") + + // Test that local resolver uses the correct priority + server := &DefaultServer{ + localResolver: local.NewResolver(), + } + + config := nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + assert.Len(t, localMuxUpdates, 1) + assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") + assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 2fbfb3b91..c44d36599 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,6 +2,7 @@ package dns import ( "context" + "crypto/rand" "crypto/sha256" "encoding/hex" "errors" @@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := GenerateRequestID() + logger := log.WithField("request_id", requestID) var err error defer func() { u.checkUpstreamFails(err) }() - log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } select { case <-u.ctx.Done(): - log.Tracef("%s has been stopped", u) + logger.Tracef("%s has been stopped", u) return default: } @@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if err != nil { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) continue } - log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) continue } if rm == nil || !rm.Response { - log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) continue } u.successCount.Add(1) - log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) if err = w.WriteMsg(rm); err != nil { - log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) + logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) } // count the fails only if they happen sequentially u.failsCount.Store(0) return } u.failsCount.Add(1) - log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) m.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(m); err != nil { - log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) } } @@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } + +func GenerateRequestID() string { + bytes := make([]byte, 4) + _, err := rand.Read(bytes) + if err != nil { + log.Errorf("failed to generate request ID: %v", err) + return "" + } + return hex.EncodeToString(bytes) +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 52d2ba58b..e7db581b1 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -84,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool { } return false } + +func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { + return &dns.Client{ + Timeout: dialTimeout, + Net: "udp", + }, nil +} diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1bc06a7c1..317588a27 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -36,3 +36,10 @@ func newUpstreamResolver( func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) } + +func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { + return &dns.Client{ + Timeout: dialTimeout, + Net: "udp", + }, nil +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 45b479632..506c429cd 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -18,14 +18,20 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" const upstreamTimeout = 15 * time.Second +type resolver interface { + LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) +} + +type firewaller interface { + UpdateSet(set firewall.Set, prefixes []netip.Prefix) error +} + type DNSForwarder struct { listenAddress string ttl uint32 @@ -38,16 +44,18 @@ type DNSForwarder struct { mutex sync.RWMutex fwdEntries []*ForwarderEntry - firewall firewall.Manager + firewall firewaller + resolver resolver } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, firewall: firewall, statusRecorder: statusRecorder, + resolver: net.DefaultResolver, } } @@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // UDP server mux := dns.NewServeMux() f.mux = mux + mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } + // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux + tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ Addr: f.listenAddress, Net: "tcp", @@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // return the first error we get (e.g. bind failure or shutdown) return <-errCh } + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() - if f.mux == nil { - log.Debug("DNS mux is nil, skipping domain update") - f.fwdEntries = entries - return - } - - oldDomains := filterDomains(f.fwdEntries) - for _, d := range oldDomains { - f.mux.HandleRemove(d.PunycodeString()) - f.tcpMux.HandleRemove(d.PunycodeString()) - } - - newDomains := filterDomains(entries) - for _, d := range newDomains { - f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) - f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) - } - f.fwdEntries = entries - log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) + log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + // query doesn't match any configured domain + if mostSpecificResId == "" { + resp.Rcode = dns.RcodeRefused + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) + ips, err := f.resolver.LookupNetIP(ctx, network, domain) if err != nil { f.handleDNSError(w, query, resp, domain, err) return nil } - f.updateInternalState(domain, ips) + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) return resp } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { - resp := f.handleDNSQuery(w, query) if resp == nil { return @@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { } } -func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { +func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { var prefixes []netip.Prefix - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) if mostSpecificResId != "" { for _, ip := range ips { var prefix netip.Prefix @@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar return selectedResId, matches } - -// filterDomains returns a list of normalized domains -func filterDomains(entries []*ForwarderEntry) domain.List { - newDomains := make(domain.List, 0, len(entries)) - for _, d := range entries { - if d.Domain == "" { - log.Warn("empty domain in DNS forwarder") - continue - } - newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) - } - return newDomains -} diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index f0829bbbd..d8228c733 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,11 +1,21 @@ package dnsfwd import ( + "context" + "fmt" + "net/netip" + "strings" "testing" + "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -13,7 +23,7 @@ import ( func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]route.ResID // key: domain pattern, value: resId + storedMappings map[string]route.ResID queryDomain string expectedResId route.ResID }{ @@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) { { name: "Wildcard pattern does not match different domain", storedMappings: map[string]route.ResID{"*.example.com": "res4"}, - queryDomain: "foo.notexample.com", + queryDomain: "foo.example.org", expectedResId: "", }, { @@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) { }) } } + +type MockFirewall struct { + mock.Mock +} + +func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + args := m.Called(set, prefixes) + return args.Error(0) +} + +type MockResolver struct { + mock.Mock +} + +func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + args := m.Called(ctx, network, host) + return args.Get(0).([]netip.Addr), args.Error(1) +} + +func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) { + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldMatch bool + expectedResID route.ResID + description string + }{ + { + name: "exact domain match should be allowed", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Direct match to configured domain should work", + }, + { + name: "subdomain access should be restricted", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Subdomain should not be accessible unless explicitly configured", + }, + { + name: "wildcard should allow subdomains", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard domains should allow subdomain access", + }, + { + name: "wildcard should allow base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should also match the base domain", + }, + { + name: "deep subdomain should be restricted", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Deep subdomains should not be accessible", + }, + { + name: "wildcard allows deep subdomains", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should allow deep subdomains", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := &DNSForwarder{} + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + }, + } + + forwarder.UpdateDomains(entries) + + resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain) + + if tt.shouldMatch { + assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID") + assert.NotEmpty(t, matchingEntries, "Expected matching entries") + t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain) + } else { + assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match") + assert.Empty(t, matchingEntries, "Expected no matching entries") + t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain) + } + }) + } +} + +func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldResolve bool + description string + }{ + { + name: "configured exact domain resolves", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Exact match should resolve", + }, + { + name: "unauthorized subdomain blocked", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldResolve: false, + description: "Subdomain should be blocked without wildcard", + }, + { + name: "wildcard allows subdomain", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldResolve: true, + description: "Wildcard should allow subdomain", + }, + { + name: "wildcard allows base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Wildcard should allow base domain", + }, + { + name: "unrelated domain blocked", + configuredDomain: "example.com", + queryDomain: "example.org", + shouldResolve: false, + description: "Unrelated domain should be blocked", + }, + { + name: "deep subdomain blocked", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: false, + description: "Deep subdomain should be blocked", + }, + { + name: "wildcard allows deep subdomain", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: true, + description: "Wildcard should allow deep subdomain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + if tt.shouldResolve { + mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil) + + // Mock successful DNS resolution + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) + } + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + Set: firewall.NewDomainSet([]domain.Domain{d}), + }, + } + + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") + assert.NotEmpty(t, resp.Answer, "Expected DNS answer records") + + time.Sleep(10 * time.Millisecond) + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + } else { + if resp != nil { + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") + } + mockFirewall.AssertNotCalled(t, "UpdateSet") + mockResolver.AssertNotCalled(t, "LookupNetIP") + } + }) + } +} + +func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { + tests := []struct { + name string + configuredDomains []string + query string + mockIP string + shouldResolve bool + expectedSetCount int // How many sets should be updated + description string + }{ + { + name: "exact domain gets firewall update", + configuredDomains: []string{"example.com"}, + query: "example.com", + mockIP: "1.1.1.1", + shouldResolve: true, + expectedSetCount: 1, + description: "Single exact match updates one set", + }, + { + name: "wildcard domain gets firewall update", + configuredDomains: []string{"*.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.2", + shouldResolve: true, + expectedSetCount: 1, + description: "Wildcard match updates one set", + }, + { + name: "overlapping exact and wildcard both get updates", + configuredDomains: []string{"*.example.com", "mail.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.3", + shouldResolve: true, + expectedSetCount: 2, + description: "Both exact and wildcard sets should be updated", + }, + { + name: "unauthorized domain gets no firewall update", + configuredDomains: []string{"example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.4", + shouldResolve: false, + expectedSetCount: 0, + description: "No firewall update for unauthorized domains", + }, + { + name: "multiple wildcards matching get all updated", + configuredDomains: []string{"*.example.com", "*.sub.example.com"}, + query: "test.sub.example.com", + mockIP: "1.1.1.5", + shouldResolve: true, + expectedSetCount: 2, + description: "All matching wildcard sets should be updated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + // Set up forwarder + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Create entries and track sets + var entries []*ForwarderEntry + sets := make([]firewall.Set, 0) + + for i, configDomain := range tt.configuredDomains { + d, err := domain.FromString(configDomain) + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + sets = append(sets, set) + + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID(fmt.Sprintf("res-%d", i)), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Set up mocks + if tt.shouldResolve { + fakeIP := netip.MustParseAddr(tt.mockIP) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)). + Return([]netip.Addr{fakeIP}, nil).Once() + + expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)} + + // Count how many sets should actually match + updateCount := 0 + for i, entry := range entries { + domain := strings.ToLower(tt.query) + pattern := entry.Domain.PunycodeString() + + matches := false + if strings.HasPrefix(pattern, "*.") { + baseDomain := strings.TrimPrefix(pattern, "*.") + if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + matches = true + } + } else if domain == pattern { + matches = true + } + + if matches { + mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once() + updateCount++ + } + } + + assert.Equal(t, tt.expectedSetCount, updateCount, + "Expected %d sets to be updated, but mock expects %d", + tt.expectedSetCount, updateCount) + } + + // Execute query + dnsQuery := &dns.Msg{} + dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, dnsQuery) + + // Verify response + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.NotEmpty(t, resp.Answer) + } else if resp != nil { + assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, + "Unauthorized domain should be refused or have no answers") + } + + // Verify all mock expectations were met + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + }) + } +} + +// Test to verify that multiple IPs for one domain result in all prefixes being sent together +func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Configure a single domain + d, err := domain.FromString("example.com") + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + entries := []*ForwarderEntry{{ + Domain: d, + ResID: "test-res", + Set: set, + }} + + forwarder.UpdateDomains(entries) + + // Mock resolver returns multiple IPs + ips := []netip.Addr{ + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("1.1.1.2"), + netip.MustParseAddr("1.1.1.3"), + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return(ips, nil).Once() + + // Expect ONE UpdateSet call with ALL prefixes + expectedPrefixes := []netip.Prefix{ + netip.PrefixFrom(ips[0], 32), + netip.PrefixFrom(ips[1], 32), + netip.PrefixFrom(ips[2], 32), + } + mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once() + + // Execute query + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + // Verify response contains all IPs + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 3, "Should have 3 answer records") + + // Verify mocks + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) +} + +func TestDNSForwarder_ResponseCodes(t *testing.T) { + tests := []struct { + name string + queryType uint16 + queryDomain string + configured string + expectedCode int + description string + }{ + { + name: "unauthorized domain returns REFUSED", + queryType: dns.TypeA, + queryDomain: "evil.com", + configured: "example.com", + expectedCode: dns.RcodeRefused, + description: "RFC compliant REFUSED for unauthorized queries", + }, + { + name: "unsupported query type returns NOTIMP", + queryType: dns.TypeMX, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "RFC compliant NOTIMP for unsupported types", + }, + { + name: "CNAME query returns NOTIMP", + queryType: dns.TypeCNAME, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "CNAME queries not supported", + }, + { + name: "TXT query returns NOTIMP", + queryType: dns.TypeTXT, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "TXT queries not supported", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + d, err := domain.FromString(tt.configured) + require.NoError(t, err) + + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType) + + // Capture the written response + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + + _ = forwarder.handleDNSQuery(mockWriter, query) + + // Check the response written to the writer + require.NotNil(t, writtenResp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + }) + } +} + +func TestDNSForwarder_TCPTruncation(t *testing.T) { + // Test that large UDP responses are truncated with TC bit set + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, _ := domain.FromString("example.com") + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + // Mock many IPs to create a large response + var manyIPs []netip.Addr + for i := 0; i < 100; i++ { + manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256))) + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil) + + // Query without EDNS0 + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + forwarder.handleDNSQueryUDP(mockWriter, query) + + require.NotNil(t, writtenResp) + assert.True(t, writtenResp.Truncated, "Large response should be truncated") + assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") +} + +func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { + // Test complex overlapping pattern scenarios + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Set up complex overlapping patterns + patterns := []string{ + "*.example.com", // Matches all subdomains + "*.mail.example.com", // More specific wildcard + "smtp.mail.example.com", // Exact match + "example.com", // Base domain + } + + var entries []*ForwarderEntry + sets := make(map[string]firewall.Set) + + for _, pattern := range patterns { + d, _ := domain.FromString(pattern) + set := firewall.NewDomainSet([]domain.Domain{d}) + sets[pattern] = set + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID("res-" + pattern), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Test smtp.mail.example.com - should match 3 patterns + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil) + + expectedPrefix := netip.PrefixFrom(fakeIP, 32) + // All three matching patterns should get firewall updates + mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + + query := &dns.Msg{} + query.SetQuestion("smtp.mail.example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + + // Verify all three sets were updated + mockFirewall.AssertExpectations(t) + + // Verify the most specific ResID was selected + // (exact match should win over wildcards) + resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com") + assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID) + assert.Len(t, matches, 3, "Should match 3 patterns") +} + +func TestDNSForwarder_EmptyQuery(t *testing.T) { + // Test handling of malformed query with no questions + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + query := &dns.Msg{} + // Don't set any question + + writeCalled := false + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writeCalled = true + return nil + }, + } + resp := forwarder.handleDNSQuery(mockWriter, query) + + assert.Nil(t, resp, "Should return nil for empty query") + assert.False(t, writeCalled, "Should not write response for empty query") +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 882cf5578..39637dd2f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { MTU: iface.DefaultMTU, TransportNet: transportNet, FilterFn: e.addrViaRoutes, + DisableDNS: e.config.DisableDNS, } switch runtime.GOOS { diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index d01adf135..dbb4747a5 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { eventStr = "Ended" } - log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort) + log.Tracef("%s %s %s connection: %s:%d → %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort) c.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: flowID, diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 629afec9b..e290ef75f 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() + defer d.mux.Unlock() if !d.peerListChangedForNotification { - d.mux.Unlock() return } d.peerListChangedForNotification = false - d.mux.Unlock() d.notifyPeerListChanged() diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index 11c0f5708..46bff96db 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "reflect" - "runtime" "time" log "github.com/sirupsen/logrus" @@ -23,7 +22,7 @@ import ( const ( handlerTypeDynamic = iota - handlerTypeDomain + handlerTypeDnsInterceptor handlerTypeStatic ) @@ -566,13 +565,14 @@ func HandlerFromRoute( useNewDNSRoute bool, ) RouteHandler { switch handlerType(rt, useNewDNSRoute) { - case handlerTypeDomain: + case handlerTypeDnsInterceptor: return dnsinterceptor.New( rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer, + wgInterface, peerStore, ) case handlerTypeDynamic: @@ -596,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int { return handlerTypeStatic } - if useNewDNSRoute && runtime.GOOS != "ios" { - return handlerTypeDomain + if useNewDNSRoute { + return handlerTypeDnsInterceptor } return handlerTypeDynamic } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 78d5e3b30..66557e888 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" @@ -23,6 +24,11 @@ import ( type domainMap map[domain.Domain][]netip.Prefix +type wgInterface interface { + Name() string + Address() wgaddr.Address +} + type DnsInterceptor struct { mu sync.RWMutex route *route.Route @@ -32,6 +38,7 @@ type DnsInterceptor struct { dnsServer nbdns.Server currentPeerKey string interceptedDomains domainMap + wgInterface wgInterface peerStore *peerstore.Store } @@ -41,6 +48,7 @@ func New( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, statusRecorder *peer.Status, dnsServer nbdns.Server, + wgInterface wgInterface, peerStore *peerstore.Store, ) *DnsInterceptor { return &DnsInterceptor{ @@ -49,6 +57,7 @@ func New( allowedIPsRefcounter: allowedIPsRefCounter, statusRecorder: statusRecorder, dnsServer: dnsServer, + wgInterface: wgInterface, interceptedDomains: make(domainMap), peerStore: peerStore, } @@ -135,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { // ServeDNS implements the dns.Handler interface func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := nbdns.GenerateRequestID() + logger := log.WithField("request_id", requestID) + if len(r.Question) == 0 { return } - log.Tracef("received DNS request for domain=%s type=%v class=%v", + logger.Tracef("received DNS request for domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // pass if non A/AAAA query if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { - d.continueToNextHandler(w, r, "non A/AAAA query") + d.continueToNextHandler(w, r, logger, "non A/AAAA query") return } @@ -152,29 +164,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { d.mu.RUnlock() if peerKey == "" { - d.writeDNSError(w, r, "no current peer key") + d.writeDNSError(w, r, logger, "no current peer key") return } upstreamIP, err := d.getUpstreamIP(peerKey) if err != nil { - d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) + d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err)) + return + } + + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + if err != nil { + d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return } if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } - client := &dns.Client{ - Timeout: nbdns.UpstreamTimeout, - Net: "udp", - } + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) if err != nil { - log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } return } @@ -184,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { answer = reply.Answer } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) + logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) reply.Id = r.Id if err := d.writeMsg(w, reply); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } } -func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) +func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS error response: %v", err) + logger.Errorf("failed to write DNS error response: %v", err) } } // continueToNextHandler signals the handler chain to try the next handler -func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeNameError) // Set Zero bit to signal handler chain to continue resp.MsgHdr.Zero = true if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed writing DNS continue response: %v", err) + logger.Errorf("failed writing DNS continue response: %v", err) } } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 63bad689e..742294cdf 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -15,7 +15,7 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) - UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error + UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap diff --git a/client/internal/routemanager/notifier/notifier.go b/client/internal/routemanager/notifier/notifier.go index 25a3a71e0..3cc7c3308 100644 --- a/client/internal/routemanager/notifier/notifier.go +++ b/client/internal/routemanager/notifier/notifier.go @@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { nets := make([]string, 0) for _, r := range clientRoutes { - // filter out domain routes if r.IsDynamic() { continue } @@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) { if runtime.GOOS != "android" { return } - newNets := make([]string, 0) + + var newNets []string for _, routes := range idMap { for _, r := range routes { + if r.IsDynamic() { + continue + } newNets = append(newNets, r.Network.String()) } } sort.Strings(newNets) - switch runtime.GOOS { - case "android": - if !n.hasDiff(n.initialRouteRanges, newNets) { - return - } - default: - if !n.hasDiff(n.routeRanges, newNets) { - return - } + if !n.hasDiff(n.initialRouteRanges, newNets) { + return } n.routeRanges = newNets - n.notify() } +// OnNewPrefixes is called from iOS only func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { newNets := make([]string, 0) for _, prefix := range prefixes { @@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { } sort.Strings(newNets) - switch runtime.GOOS { - case "android": - if !n.hasDiff(n.initialRouteRanges, newNets) { - return - } - default: - if !n.hasDiff(n.routeRanges, newNets) { - return - } + if !n.hasDiff(n.routeRanges, newNets) { + return } n.routeRanges = newNets - n.notify() } diff --git a/client/internal/state.go b/client/internal/state.go index 4ae99d944..041cb73f8 100644 --- a/client/internal/state.go +++ b/client/internal/state.go @@ -10,10 +10,11 @@ type StatusType string const ( StatusIdle StatusType = "Idle" - StatusConnecting StatusType = "Connecting" - StatusConnected StatusType = "Connected" - StatusNeedsLogin StatusType = "NeedsLogin" - StatusLoginFailed StatusType = "LoginFailed" + StatusConnecting StatusType = "Connecting" + StatusConnected StatusType = "Connected" + StatusNeedsLogin StatusType = "NeedsLogin" + StatusLoginFailed StatusType = "LoginFailed" + StatusSessionExpired StatusType = "SessionExpired" ) // CtxInitState setup context state into the context tree. diff --git a/client/server/server.go b/client/server/server.go index 72837b59d..31a437c99 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -8,6 +8,7 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "time" "github.com/cenkalti/backoff/v4" @@ -66,6 +67,7 @@ type Server struct { lastProbe time.Time persistNetworkMap bool + isSessionActive atomic.Bool } type oauthAuthFlow struct { @@ -567,9 +569,6 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo) if err != nil { - if err == context.Canceled { - return nil, nil //nolint:nilnil - } s.mutex.Lock() s.oauthAuthFlow.expiresAt = time.Now() s.mutex.Unlock() @@ -640,6 +639,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes for { select { case <-runningChan: + s.isSessionActive.Store(true) return &proto.UpResponse{}, nil case <-callerCtx.Done(): log.Debug("context done, stopping the wait for engine to become ready") @@ -668,6 +668,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes log.Errorf("failed to shut down properly: %v", err) return nil, err } + s.isSessionActive.Store(false) state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) @@ -694,6 +695,12 @@ func (s *Server) Status( return nil, err } + if status == internal.StatusNeedsLogin && s.isSessionActive.Load() { + log.Debug("status requested while session is active, returning SessionExpired") + status = internal.StatusSessionExpired + s.isSessionActive.Store(false) + } + statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()} s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) diff --git a/client/system/info.go b/client/system/info.go index a0a5fe8b3..aff10ece3 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -59,16 +59,16 @@ type Info struct { Environment Environment Files []File // for posture checks - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed bool - DisableClientRoutes bool - DisableServerRoutes bool - DisableDNS bool - DisableFirewall bool - BlockLANAccess bool - BlockInbound bool + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + BlockLANAccess bool + BlockInbound bool LazyConnectionEnabled bool } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 554cfdc44..00a535dd6 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -20,7 +20,10 @@ import ( "fyne.io/fyne/v2" "fyne.io/fyne/v2/app" + "fyne.io/fyne/v2/canvas" + "fyne.io/fyne/v2/container" "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/layout" "fyne.io/fyne/v2/theme" "fyne.io/fyne/v2/widget" "fyne.io/systray" @@ -51,7 +54,7 @@ const ( ) func main() { - daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() + daemonAddr, showSettings, showNetworks, showLoginURL, showDebug, errorMsg, saveLogsInFile := parseFlags() // Initialize file logging if needed. var logFile string @@ -77,13 +80,13 @@ func main() { } // Create the service client (this also builds the settings or networks UI if requested). - client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug) + client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showLoginURL, showDebug) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if showSettings || showNetworks || showDebug { + if showSettings || showNetworks || showDebug || showLoginURL { a.Run() return } @@ -104,7 +107,7 @@ func main() { } // parseFlags reads and returns all needed command-line flags. -func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) { +func parseFlags() (daemonAddr string, showSettings, showNetworks, showLoginURL, showDebug bool, errorMsg string, saveLogsInFile bool) { defaultDaemonAddr := "unix:///var/run/netbird.sock" if runtime.GOOS == "windows" { defaultDaemonAddr = "tcp://127.0.0.1:41731" @@ -112,6 +115,7 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showNetworks, "networks", false, "run networks window") + flag.BoolVar(&showLoginURL, "login-url", false, "show login URL in a popup window") flag.BoolVar(&showDebug, "debug", false, "run debug window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) @@ -253,6 +257,7 @@ type serviceClient struct { exitNodeStates []exitNodeState mExitNodeDeselectAll *systray.MenuItem logFile string + wLoginURL fyne.Window } type menuHandler struct { @@ -263,7 +268,7 @@ type menuHandler struct { // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { +func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showLoginURL bool, showDebug bool) *serviceClient { ctx, cancel := context.WithCancel(context.Background()) s := &serviceClient{ ctx: ctx, @@ -286,6 +291,8 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool s.showSettingsUI() case showNetworks: s.showNetworksUI() + case showLoginURL: + s.showLoginURL() case showDebug: s.showDebugUI() } @@ -445,11 +452,11 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } -func (s *serviceClient) login() error { +func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf("get client: %v", err) - return err + return nil, err } loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ @@ -457,24 +464,24 @@ func (s *serviceClient) login() error { }) if err != nil { log.Errorf("login to management URL with: %v", err) - return err + return nil, err } - if loginResp.NeedsSSOLogin { + if loginResp.NeedsSSOLogin && openURL { err = open.Run(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) - return err + return nil, err } _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { log.Errorf("waiting sso login failed with: %v", err) - return err + return nil, err } } - return nil + return loginResp, nil } func (s *serviceClient) menuUpClick() error { @@ -486,7 +493,7 @@ func (s *serviceClient) menuUpClick() error { return err } - err = s.login() + _, err = s.login(true) if err != nil { log.Errorf("login failed with: %v", err) return err @@ -558,7 +565,7 @@ func (s *serviceClient) updateStatus() error { defer s.updateIndicationLock.Unlock() // notify the user when the session has expired - if status.Status == string(internal.StatusNeedsLogin) { + if status.Status == string(internal.StatusSessionExpired) { s.onSessionExpire() } @@ -732,7 +739,6 @@ func (s *serviceClient) onTrayReady() { go s.eventHandler.listen(s.ctx) } - func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { if s.logFile == "" { // attach child's streams to parent's streams @@ -871,17 +877,9 @@ func (s *serviceClient) onUpdateAvailable() { // onSessionExpire sends a notification to the user when the session expires. func (s *serviceClient) onSessionExpire() { + s.sendNotification = true if s.sendNotification { - title := "Connection session expired" - if runtime.GOOS == "darwin" { - title = "NetBird connection session expired" - } - s.app.SendNotification( - fyne.NewNotification( - title, - "Please re-authenticate to connect to the network", - ), - ) + s.eventHandler.runSelfCommand("login-url", "true") s.sendNotification = false } } @@ -955,9 +953,9 @@ func (s *serviceClient) updateConfig() error { ServerSSHAllowed: &sshAllowed, RosenpassEnabled: &rosenpassEnabled, DisableAutoConnect: &disableAutoStart, + DisableNotifications: ¬ificationsDisabled, LazyConnectionEnabled: &lazyConnectionEnabled, BlockInbound: &blockInbound, - DisableNotifications: ¬ificationsDisabled, } if err := s.restartClient(&loginRequest); err != nil { @@ -991,6 +989,87 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error { return nil } +// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. +func (s *serviceClient) showLoginURL() { + + resp, err := s.login(false) + if err != nil { + log.Errorf("failed to fetch login URL: %v", err) + return + } + verificationURL := resp.VerificationURIComplete + if verificationURL == "" { + verificationURL = resp.VerificationURI + } + + if verificationURL == "" { + log.Error("no verification URL provided in the login response") + return + } + + resIcon := fyne.NewStaticResource("netbird.png", iconAbout) + + if s.wLoginURL == nil { + s.wLoginURL = s.app.NewWindow("NetBird Session Expired") + s.wLoginURL.Resize(fyne.NewSize(400, 200)) + s.wLoginURL.SetIcon(resIcon) + } + // add a description label + label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") + + btn := widget.NewButtonWithIcon("Re-authenticate", theme.ViewRefreshIcon(), func() { + + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return + } + + if err := openURL(verificationURL); err != nil { + log.Errorf("failed to open login URL: %v", err) + return + } + + _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + if err != nil { + log.Errorf("Waiting sso login failed with: %v", err) + label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") + return + } + + label.SetText("Re-authentication successful.\nReconnecting") + time.Sleep(300 * time.Millisecond) + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") + log.Errorf("Reconnecting failed with: %v", err) + return + } + + label.SetText("Connection successful.\nClosing this window.") + time.Sleep(time.Second) + + s.wLoginURL.Close() + }) + + img := canvas.NewImageFromResource(resIcon) + img.FillMode = canvas.ImageFillContain + img.SetMinSize(fyne.NewSize(64, 64)) + img.Resize(fyne.NewSize(64, 64)) + + // center the content vertically + content := container.NewVBox( + layout.NewSpacer(), + img, + label, + btn, + layout.NewSpacer(), + ) + s.wLoginURL.SetContent(container.NewCenter(content)) + + s.wLoginURL.Show() +} + func openURL(url string) error { var err error switch runtime.GOOS { diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index f7072c6b8..5441f3481 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" ) type eventHandler struct { @@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() { } func (h *eventHandler) handleUpdateClick() { - if err := openURL("https://netbird.io/download"); err != nil { + if err := openURL(version.DownloadUrl()); err != nil { log.Errorf("failed to open download URL: %v", err) } } diff --git a/client/ui/network.go b/client/ui/network.go index b3748a89d..fb73efd7b 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -358,8 +358,6 @@ func (s *serviceClient) updateExitNodes() { } else { s.mExitNode.Disable() } - - log.Debugf("Exit nodes updated: %d", len(s.mExitNodeItems)) } func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { diff --git a/go.mod b/go.mod index 11dc88c43..a12058278 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c + github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index f887cee94..6ce503dd1 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c h1:SdZxYjR9XXHLyRsTbS1EHBr6+RI15oie1K9Q8yvi3FY= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= diff --git a/management/cmd/management.go b/management/cmd/management.go index 5fb07890f..bce09efdd 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -159,6 +159,12 @@ var ( if err != nil { return err } + + integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics) + if err != nil { + return err + } + store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) @@ -176,7 +182,7 @@ var ( if disableSingleAccMode { mgmtSingleAccModeDomain = "" } - eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey) + eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics) if err != nil { return fmt.Errorf("failed to initialize database: %s", err) } diff --git a/management/server/account.go b/management/server/account.go index 82f5ee4a3..daeaf6e55 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -24,6 +24,7 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -409,14 +410,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.schedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + am.schedulePeerLoginExpiration(ctx, accountID) } } @@ -454,6 +456,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + //nolint + ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + //nolint + ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource)) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -478,8 +484,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { - am.peerLoginExpiry.Cancel(ctx, []string{accountID}) +func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) { + if am.peerLoginExpiry.IsSchedulerRunning(accountID) { + log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID) + return + } if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } diff --git a/management/server/account_test.go b/management/server/account_test.go index ba0191c03..c3b1f31a6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "expecting to update account settings successfully but got error") wg := &sync.WaitGroup{} - wg.Add(2) + wg.Add(1) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(ctx context.Context, IDs []string) { - wg.Done() - }, ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, diff --git a/management/server/group.go b/management/server/group.go index c26a0cfc1..130a67145 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if group, exists := account.Groups[groupID]; exists && group.HasPeers() { - return true - } - } - return false -} - // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) diff --git a/management/server/peer.go b/management/server/peer.go index 4a468a6cd..f2469e09b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.schedulePeerLoginExpiration(ctx, accountID) } if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { @@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + am.schedulePeerLoginExpiration(ctx, accountID) } } diff --git a/management/server/route.go b/management/server/route.go index 02755a708..32ff39977 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,19 +4,19 @@ import ( "context" "fmt" "net/netip" + "slices" "unicode/utf8" "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.NewPermissionDeniedError() } - return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + prefix := checkRoute.Network + domains := checkRoute.Domains + + routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains) + if err != nil { + return err + } // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, prefixRoute := range routesWithPrefix { // we skip route(s) with the same network ID as we want to allow updating of the existing route // when creating a new route routeID is newly generated so nothing will be skipped - if routeID == prefixRoute.ID { + if checkRoute.ID == prefixRoute.ID { continue } if prefixRoute.Peer != "" { seenPeers[string(prefixRoute.ID)] = true } + + peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) + if err != nil { + return err + } + for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group := account.GetGroup(groupID) - if group == nil { + group, ok := peerGroupsMap[groupID] + if !ok || group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } } - if peerID != "" { + if peerID := checkRoute.Peer; peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer := account.GetPeer(peerID) - if peer == nil { + _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) + if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that peerGroupIDs are not in any route peerGroups list - for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. - + for _, groupID := range checkRoute.PeerGroups { + group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", @@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix + peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) + if err != nil { + return err + } + for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(id) - if peer == nil { - return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + peer, ok := peersMap[id] + if !ok || peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id) } + return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, status.NewPermissionDeniedError() } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } - if len(domains) == 0 && !prefix.IsValid() { - return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") - } + var newRoute *route.Route + var updateAccountPeers bool - if len(domains) > 0 { - prefix = getPlaceholderIP() - } - - if peerID != "" && len(peerGroupIDs) != 0 { - return nil, status.Errorf( - status.InvalidArgument, - "peer with ID %s and peers group %s should not be provided at the same time", - peerID, peerGroupIDs) - } - - var newRoute route.Route - newRoute.ID = route.ID(xid.New().String()) - - if len(peerGroupIDs) > 0 { - err = validateGroups(peerGroupIDs, account.Groups) - if err != nil { - return nil, err + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + newRoute = &route.Route{ + ID: route.ID(xid.New().String()), + AccountID: accountID, + Network: prefix, + Domains: domains, + KeepRoute: keepRoute, + NetID: netID, + Description: description, + Peer: peerID, + PeerGroups: peerGroupIDs, + NetworkType: networkType, + Masquerade: masquerade, + Metric: metric, + Enabled: enabled, + Groups: groups, + AccessControlGroups: accessControlGroupIDs, } - } - if len(accessControlGroupIDs) > 0 { - err = validateGroups(accessControlGroupIDs, account.Groups) - if err != nil { - return nil, err + if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { + return err } - } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) + }) if err != nil { return nil, err } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) - } - - if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - - err = validateGroups(groups, account.Groups) - if err != nil { - return nil, err - } - - newRoute.Peer = peerID - newRoute.PeerGroups = peerGroupIDs - newRoute.Network = prefix - newRoute.Domains = domains - newRoute.NetworkType = networkType - newRoute.Description = description - newRoute.NetID = netID - newRoute.Masquerade = masquerade - newRoute.Metric = metric - newRoute.Enabled = enabled - newRoute.Groups = groups - newRoute.KeepRoute = keepRoute - newRoute.AccessControlGroups = accessControlGroupIDs - - if account.Routes == nil { - account.Routes = make(map[route.ID]*route.Route) - } - - account.Routes[newRoute.ID] = &newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if am.isRouteChangeAffectPeers(account, &newRoute) { - am.UpdateAccountPeers(ctx, accountID) - } - am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - return &newRoute, nil + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return newRoute, nil } // SaveRoute saves route @@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var oldRoute *route.Route + var oldRouteAffectsPeers bool + var newRouteAffectsPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { + return err + } + + oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID)) + if err != nil { + return err + } + + oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute) + if err != nil { + return err + } + + newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave) + if err != nil { + return err + } + routeToSave.AccountID = accountID + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + + if oldRouteAffectsPeers || newRouteAffectsPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// DeleteRoute deletes route with routeID +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var route *route.Route + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + if err != nil { + return err + } + + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + }) + + am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// ListRoutes returns a list of routes from account +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) +} + +func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { if routeToSave == nil { return status.Errorf(status.InvalidArgument, "route provided is nil") } @@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } + groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave) + if err != nil { + return err + } + + return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap) +} + +// validateRouteGroups validates the route groups and returns the validated groups map. +func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { + groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) + groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) + if err != nil { + return nil, err + } + if len(routeToSave.PeerGroups) > 0 { - err = validateGroups(routeToSave.PeerGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil { + return nil, err } } if len(routeToSave.AccessControlGroups) > 0 { - err = validateGroups(routeToSave.AccessControlGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil { + return nil, err } } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) - if err != nil { - return err + if err = validateGroups(routeToSave.Groups, groupsMap); err != nil { + return nil, err } - err = validateGroups(routeToSave.Groups, account.Groups) - if err != nil { - return err - } - - oldRoute := account.Routes[routeToSave.ID] - account.Routes[routeToSave.ID] = routeToSave - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.UpdateAccountPeers(ctx, accountID) - } - - am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) - - return nil -} - -// DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - routy := account.Routes[routeID] - if routy == nil { - return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) - } - delete(account.Routes, routeID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - - if am.isRouteChangeAffectPeers(account, routy) { - am.UpdateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + return groupsMap, nil } func toProtocolRoute(route *route.Route) *proto.Route { @@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { return &portInfo } -// isRouteChangeAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { - return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +// areRouteChangesAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers. +func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { + if route.Peer != "" { + return true, nil + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups) +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { + accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + routes := make([]*route.Route, 0) + for _, r := range accountRoutes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes, nil } diff --git a/management/server/scheduler.go b/management/server/scheduler.go index 147b50fc6..df73c9a1d 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -12,6 +12,7 @@ import ( type Scheduler interface { Cancel(ctx context.Context, IDs []string) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + IsSchedulerRunning(ID string) bool } // MockScheduler is a mock implementation of Scheduler @@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) { mock.CancelFunc(ctx, IDs) return } - log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") + log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ") } // Schedule mocks the Schedule function of the Scheduler interface @@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st mock.ScheduleFunc(ctx, in, ID, job) return } - log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") + log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined") +} + +func (mock *MockScheduler) IsSchedulerRunning(ID string) bool { + // MockScheduler does not implement IsSchedulerRunning, so we return false + log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined") + return false } // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. @@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s }() } + +// IsSchedulerRunning checks if a job with the provided ID is scheduled to run +func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool { + wm.mu.Lock() + defer wm.mu.Unlock() + _, ok := wm.jobs[ID] + return ok +} diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index a561de40d..cecf55200 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "fmt" - "strconv" "strings" "testing" "time" @@ -182,7 +181,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, - tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), + tCase.expectedCreatedAt, tCase.expectedExpiresAt, key.Id, tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated @@ -258,10 +257,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key, plainKey := types.GenerateDefaultSetupKey() + key, _ := types.GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true) } @@ -275,10 +274,10 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false) + key, _ := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true) } diff --git a/management/server/status/error.go b/management/server/status/error.go index 8fbe0bad9..5a6f6d1a7 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error { func NewOperationNotFoundError(operation operations.Operation) error { return Errorf(NotFound, "operation: %s not found", operation) } + +func NewRouteNotFoundError(routeID string) error { + return Errorf(NotFound, "route: %s not found", routeID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d81890775..a6c4d56bf 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -23,8 +23,6 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/logger" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -34,6 +32,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db, lockStrength, accountID) + var routes []*route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&routes, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get routes from store") + } + + return routes, nil } // GetRouteByID retrieves a route by its ID and account ID. -func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + var route *route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&route, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewRouteNotFoundError(routeID) + } + log.WithContext(ctx).Errorf("failed to get route from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get route from store") + } + + return route, nil +} + +// SaveRoute saves a route to the database. +func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) + return status.Errorf(status.Internal, "failed to save route to store") + } + + return nil +} + +// DeleteRoute deletes a route from the database. +func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete route from store") + } + + if result.RowsAffected == 0 { + return status.NewRouteNotFoundError(routeID) + } + + return nil } // GetAccountSetupKeys retrieves setup keys for an account. @@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki return nil } -// getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record []T - - result := tx.Find(&record, accountIDCondition, accountID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) - } - - return record, nil -} - -// getRecordByID retrieves a record by its ID and account ID from the database. -func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record T - - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "%s not found", recordType) - } - return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) - } - return &record, nil -} - // SaveDNSSettings saves the DNS settings to the store. func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2c1f5f8e6..fab9048e5 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -19,21 +19,17 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - - route2 "github.com/netbirdio/netbird/route" - - "github.com/netbirdio/netbird/management/server/status" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" nbroute "github.com/netbirdio/netbird/route" + route2 "github.com/netbirdio/netbird/route" ) func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { @@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } +func TestSqlStore_GetAccountRoutes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve routes by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, routes, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetRouteByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + routeID string + expectError bool + }{ + { + name: "retrieve existing route", + routeID: "ct03t427qv97vmtmglog", + expectError: false, + }, + { + name: "retrieve non-existing route", + routeID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty route ID", + routeID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, route) + } else { + require.NoError(t, err) + require.NotNil(t, route) + require.Equal(t, tt.routeID, string(route.ID)) + } + }) + } +} + +func TestSqlStore_SaveRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + route := &route2.Route{ + ID: "route-id", + AccountID: accountID, + Network: netip.MustParsePrefix("10.10.0.0/16"), + NetID: "netID", + PeerGroups: []string{"routeA"}, + NetworkType: route2.IPv4Network, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + AccessControlGroups: []string{}, + } + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + require.NoError(t, err) + + saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID)) + require.NoError(t, err) + require.Equal(t, route, saveRoute) + +} + +func TestSqlStore_DeleteRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + routeID := "ct03t427qv97vmtmglog" + + err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID) + require.NoError(t, err) + + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID) + require.Error(t, err) + require.Nil(t, route) +} func TestSqlStore_GetAccountMeta(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index c7b103454..d41379b1c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -145,7 +145,9 @@ type Store interface { DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) - GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) + SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error + DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index 09deb8127..988f91779 100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -184,10 +184,10 @@ func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpo } appMetrics.listener = listener go func() { - err := http.Serve(listener, rootRouter) - if err != nil { - return + if err := http.Serve(listener, rootRouter); err != nil && err != http.ErrServerClosed { + log.WithContext(ctx).Errorf("metrics server error: %v", err) } + log.WithContext(ctx).Info("metrics server stopped") }() log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) @@ -204,7 +204,7 @@ func (appMetrics *defaultAppMetrics) GetMeter() metric2.Meter { func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { exporter, err := prometheus.New() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create prometheus exporter: %w", err) } provider := metric.NewMeterProvider(metric.WithReader(exporter)) @@ -213,32 +213,32 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { idpMetrics, err := NewIDPMetrics(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err) } middleware, err := NewMetricsMiddleware(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err) } grpcMetrics, err := NewGRPCMetrics(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err) } storeMetrics, err := NewStoreMetrics(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize store metrics: %w", err) } updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err) } accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err) } return &defaultAppMetrics{ diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 324bf6293..0393d1ade 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); +INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go index ab8e46bea..69b381ae5 100644 --- a/management/server/types/setupkey.go +++ b/management/server/types/setupkey.go @@ -3,13 +3,12 @@ package types import ( "crypto/sha256" b64 "encoding/base64" - "hash/fnv" - "strconv" "strings" "time" "unicode/utf8" "github.com/google/uuid" + "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/util" ) @@ -170,7 +169,7 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) return &SetupKey{ - Id: strconv.Itoa(int(Hash(key))), + Id: xid.New().String(), Key: encodedHashedKey, KeySecret: HiddenKey(key, 4), Name: name, @@ -192,12 +191,3 @@ func GenerateDefaultSetupKey() (*SetupKey, string) { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, SetupKeyUnlimitedUsage, false, false) } - -func Hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} diff --git a/signal/cmd/env.go b/signal/cmd/env.go new file mode 100644 index 000000000..3c15ebe1f --- /dev/null +++ b/signal/cmd/env.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "os" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_ +func setFlagsFromEnvVars(cmd *cobra.Command) { + flags := cmd.PersistentFlags() + flags.VisitAll(func(f *pflag.Flag) { + newEnvVar := flagNameToEnvVar(f.Name, "NB_") + value, present := os.LookupEnv(newEnvVar) + if !present { + return + } + + err := flags.Set(f.Name, value) + if err != nil { + log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err) + } + }) +} + +// flagNameToEnvVar converts flag name to environment var name adding a prefix, +// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix) +func flagNameToEnvVar(cmdFlag string, prefix string) string { + parsed := strings.ReplaceAll(cmdFlag, "-", "_") + upper := strings.ToUpper(parsed) + return prefix + upper +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 3a671a848..39bc8331f 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -303,4 +303,5 @@ func init() { runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") + setFlagsFromEnvVars(runCmd) }