Merge branch 'main' into linux-file-upstream

This commit is contained in:
Viktor Liu 2025-06-18 20:52:03 +02:00
commit 7f8cea53a2
60 changed files with 1948 additions and 573 deletions

View File

@ -149,6 +149,7 @@ nfpms:
dockers: dockers:
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids: ids:
- netbird - netbird
goarch: amd64 goarch: amd64
@ -164,6 +165,7 @@ dockers:
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8 - netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids: ids:
- netbird - netbird
goarch: arm64 goarch: arm64

View File

@ -59,6 +59,8 @@ type Client struct {
deviceName string deviceName string
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@ -174,6 +176,53 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@ -0,0 +1,27 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@ -7,30 +7,23 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
} }
// PeerInfoCollection made for Java layer to get non default types as collection // PeerInfoArray is a wrapper of []PeerInfo
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct { type PeerInfoArray struct {
items []PeerInfo items []PeerInfo
} }
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray { func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
array.items = append(array.items, s) array.items = append(array.items, s)
return array return array
} }
// Get return an element of the collection // Get return an element of the collection
func (array PeerInfoArray) Get(i int) *PeerInfo { func (array *PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i] return &array.items[i]
} }
// Size return with the size of the collection // Size return with the size of the collection
func (array PeerInfoArray) Size() int { func (array *PeerInfoArray) Size() int {
return len(array.items) return len(array.items)
} }

View File

@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences export a subset of the internal config for gomobile // Preferences exports a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences create new Preferences instance // NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL read url from config file // GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL store the given url and wait for commit // SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL read url from config file // GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL store the given url and wait for commit // SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey read preshared key from config file // GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey store the given key and wait for commit // SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled store if rosenpass is enabled // SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) { func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled p.configInput.RosenpassEnabled = &enabled
} }
// GetRosenpassEnabled read rosenpass enabled from config file // GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) { func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil { if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil return *p.configInput.RosenpassEnabled, nil
@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err return cfg.RosenpassEnabled, err
} }
// SetRosenpassPermissive store the given permissive and wait for commit // SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) { func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive p.configInput.RosenpassPermissive = &permissive
} }
// GetRosenpassPermissive read rosenpass permissive from config file // GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) { func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil { if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil return *p.configInput.RosenpassPermissive, nil
@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err return cfg.RosenpassPermissive, err
} }
// Commit write out the changes into config file // GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+

View File

@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
} }

View File

@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
} }
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages { for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil { if stage.ForwardingDetails != nil {

View File

@ -62,5 +62,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }

View File

@ -3,6 +3,7 @@ package conntrack
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@ -19,6 +20,10 @@ const (
DefaultICMPTimeout = 30 * time.Second DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@ -29,7 +34,7 @@ type ICMPConnKey struct {
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID) return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@ -50,6 +55,72 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) { func (t *ICMPTracker) TrackInbound(
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size) srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) { func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
typ, code := typecode.Type(), typecode.Code() typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }

View File

@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
} }
}) })
@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
} }
b.ResetTimer() b.ResetTimer()

View File

@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
} }

View File

@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut) f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn) f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
} }
} }

View File

@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr) f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr) f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@ -671,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
} }
@ -684,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
} }

View File

@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance

View File

@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility if runtime.GOOS == "android" {
log.Infof("falling back to enabled SSH server for pre-existing configuration") // default to disabled SSH on Android for security
config.ServerSSHAllowed = util.True() log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }

View File

@ -11,9 +11,10 @@ import (
) )
const ( const (
PriorityDNSRoute = 100 PriorityLocal = 100
PriorityMatchDomain = 50 PriorityDNSRoute = 75
PriorityDefault = 1 PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100 PriorityFallback = -100
) )

View File

@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "test.example.com.", queryDomain: "test.example.com.",
@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "sub.test.example.com.", queryDomain: "sub.test.example.com.",
@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true, nbdns.PriorityUpstream: true,
}, },
}, },
{ {
@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true, nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
}, },
}, },
{ {
@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault}, {"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true, nbdns.PriorityDefault: true,
}, },
}, },
} }
@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false}, {"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true},
}, },
query: "example.com.", query: "example.com.",
@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "test.sub.example.com.", query: "test.sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true}, {"add", "example.com.", nbdns.PriorityDNSRoute, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",

View File

@ -567,7 +567,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityLocal,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@ -606,7 +606,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups) groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS { for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone { if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault basePriority = PriorityDefault
} }
@ -683,9 +683,9 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
} }
func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool {
if basePriority == PriorityMatchDomain && priority <= PriorityDefault { if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault) domainGroup.domain, PriorityUpstream-PriorityDefault)
return true return true
} }
if basePriority == PriorityDefault && priority <= PriorityFallback { if basePriority == PriorityDefault && priority <= PriorityFallback {

View File

@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
dummyHandler.ID(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"local-resolver": handlerWrapper{ "local-resolver": handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
}, },
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct { testCases := []struct {
name string name string
@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
} }
@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
"upstream-other": { "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain + 1, priority: PriorityUpstream + 1,
}, },
// Keep existing groups with their original priorities // Keep existing groups with their original priorities
{ {
@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
// Add group3 with lowest priority // Add group3 with lowest priority
{ {
@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain - 2, priority: PriorityUpstream - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "new.com", domain: "new.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-new", Id: "upstream-new",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain // Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2 // Verify refcount is 2
zoneKey := toZone("shared.example.com") zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler // Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1 // Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
} }
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
} }
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error var err error
defer func() { defer func() {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u) logger.Tracef("%s has been stopped", u)
return return
default: default:
} }
@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil { if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure) m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil { if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
} }
} }
@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@ -84,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
} }
return false return false
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@ -36,3 +36,10 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewall.Manager firewall firewaller
resolver resolver
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
} }
} }
@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server // UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
// TCP server // TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "tcp", Net: "tcp",
@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown) // return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return nil return nil
} }
f.updateInternalState(domain, ips) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) resp := f.handleDNSQuery(w, query)
if resp == nil { if resp == nil {
return return
@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" { if mostSpecificResId != "" {
for _, ip := range ips { for _, ip := range ips {
var prefix netip.Prefix var prefix netip.Prefix
@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@ -1,11 +1,21 @@
package dnsfwd package dnsfwd
import ( import (
"context"
"fmt"
"net/netip"
"strings"
"testing" "testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId storedMappings map[string]route.ResID
queryDomain string queryDomain string
expectedResId route.ResID expectedResId route.ResID
}{ }{
@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com", queryDomain: "foo.example.org",
expectedResId: "", expectedResId: "",
}, },
{ {
@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
}) })
} }
} }
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
} }
switch runtime.GOOS { switch runtime.GOOS {

View File

@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
eventStr = "Ended" eventStr = "Ended"
} }
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort) log.Tracef("%s %s %s connection: %s:%d %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{ c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID, FlowID: flowID,

View File

@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
// FinishPeerListModifications this event invoke the notification // FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() { func (d *Status) FinishPeerListModifications() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification { if !d.peerListChangedForNotification {
d.mux.Unlock()
return return
} }
d.peerListChangedForNotification = false d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged() d.notifyPeerListChanged()

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"runtime"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -23,7 +22,7 @@ import (
const ( const (
handlerTypeDynamic = iota handlerTypeDynamic = iota
handlerTypeDomain handlerTypeDnsInterceptor
handlerTypeStatic handlerTypeStatic
) )
@ -566,13 +565,14 @@ func HandlerFromRoute(
useNewDNSRoute bool, useNewDNSRoute bool,
) RouteHandler { ) RouteHandler {
switch handlerType(rt, useNewDNSRoute) { switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain: case handlerTypeDnsInterceptor:
return dnsinterceptor.New( return dnsinterceptor.New(
rt, rt,
routeRefCounter, routeRefCounter,
allowedIPsRefCounter, allowedIPsRefCounter,
statusRecorder, statusRecorder,
dnsServer, dnsServer,
wgInterface,
peerStore, peerStore,
) )
case handlerTypeDynamic: case handlerTypeDynamic:
@ -596,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
return handlerTypeStatic return handlerTypeStatic
} }
if useNewDNSRoute && runtime.GOOS != "ios" { if useNewDNSRoute {
return handlerTypeDomain return handlerTypeDnsInterceptor
} }
return handlerTypeDynamic return handlerTypeDynamic
} }

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@ -23,6 +24,11 @@ import (
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct { type DnsInterceptor struct {
mu sync.RWMutex mu sync.RWMutex
route *route.Route route *route.Route
@ -32,6 +38,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedDomains domainMap interceptedDomains domainMap
wgInterface wgInterface
peerStore *peerstore.Store peerStore *peerstore.Store
} }
@ -41,6 +48,7 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status, statusRecorder *peer.Status,
dnsServer nbdns.Server, dnsServer nbdns.Server,
wgInterface wgInterface,
peerStore *peerstore.Store, peerStore *peerstore.Store,
) *DnsInterceptor { ) *DnsInterceptor {
return &DnsInterceptor{ return &DnsInterceptor{
@ -49,6 +57,7 @@ func New(
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
wgInterface: wgInterface,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
peerStore: peerStore, peerStore: peerStore,
} }
@ -135,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
log.Tracef("received DNS request for domain=%s type=%v class=%v", logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query // pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query") d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return return
} }
@ -152,29 +164,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock() d.mu.RUnlock()
if peerKey == "" { if peerKey == "" {
d.writeDNSError(w, r, "no current peer key") d.writeDNSError(w, r, logger, "no current peer key")
return return
} }
upstreamIP, err := d.getUpstreamIP(peerKey) upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil { if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
} }
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
client := &dns.Client{
Timeout: nbdns.UpstreamTimeout,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
return return
} }
@ -184,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer answer = reply.Answer
} }
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil { if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
} }
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure) resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err) logger.Errorf("failed to write DNS error response: %v", err)
} }
} }
// continueToNextHandler signals the handler chain to try the next handler // continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError) resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue // Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err) logger.Errorf("failed writing DNS continue response: %v", err)
} }
} }

View File

@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap GetClientRoutesFunc func() route.HAMap

View File

@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() { if r.IsDynamic() {
continue continue
} }
@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" { if runtime.GOOS != "android" {
return return
} }
newNets := make([]string, 0)
var newNets []string
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String()) newNets = append(newNets, r.Network.String())
} }
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.initialRouteRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, prefix := range prefixes { for _, prefix := range prefixes {
@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.routeRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }

View File

@ -10,10 +10,11 @@ type StatusType string
const ( const (
StatusIdle StatusType = "Idle" StatusIdle StatusType = "Idle"
StatusConnecting StatusType = "Connecting" StatusConnecting StatusType = "Connecting"
StatusConnected StatusType = "Connected" StatusConnected StatusType = "Connected"
StatusNeedsLogin StatusType = "NeedsLogin" StatusNeedsLogin StatusType = "NeedsLogin"
StatusLoginFailed StatusType = "LoginFailed" StatusLoginFailed StatusType = "LoginFailed"
StatusSessionExpired StatusType = "SessionExpired"
) )
// CtxInitState setup context state into the context tree. // CtxInitState setup context state into the context tree.

View File

@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@ -66,6 +67,7 @@ type Server struct {
lastProbe time.Time lastProbe time.Time
persistNetworkMap bool persistNetworkMap bool
isSessionActive atomic.Bool
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@ -567,9 +569,6 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo) tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
if err == context.Canceled {
return nil, nil //nolint:nilnil
}
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now() s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock() s.mutex.Unlock()
@ -640,6 +639,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
for { for {
select { select {
case <-runningChan: case <-runningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
case <-callerCtx.Done(): case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready") log.Debug("context done, stopping the wait for engine to become ready")
@ -668,6 +668,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
log.Errorf("failed to shut down properly: %v", err) log.Errorf("failed to shut down properly: %v", err)
return nil, err return nil, err
} }
s.isSessionActive.Store(false)
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
@ -694,6 +695,12 @@ func (s *Server) Status(
return nil, err return nil, err
} }
if status == internal.StatusNeedsLogin && s.isSessionActive.Load() {
log.Debug("status requested while session is active, returning SessionExpired")
status = internal.StatusSessionExpired
s.isSessionActive.Store(false)
}
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()} statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())

View File

@ -59,16 +59,16 @@ type Info struct {
Environment Environment Environment Environment
Files []File // for posture checks Files []File // for posture checks
RosenpassEnabled bool RosenpassEnabled bool
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }

View File

@ -20,7 +20,10 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/fyne/v2/app" "fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/theme" "fyne.io/fyne/v2/theme"
"fyne.io/fyne/v2/widget" "fyne.io/fyne/v2/widget"
"fyne.io/systray" "fyne.io/systray"
@ -51,7 +54,7 @@ const (
) )
func main() { func main() {
daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() daemonAddr, showSettings, showNetworks, showLoginURL, showDebug, errorMsg, saveLogsInFile := parseFlags()
// Initialize file logging if needed. // Initialize file logging if needed.
var logFile string var logFile string
@ -77,13 +80,13 @@ func main() {
} }
// Create the service client (this also builds the settings or networks UI if requested). // Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug) client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showLoginURL, showDebug)
// Watch for theme/settings changes to update the icon. // Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client) go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set. // Run in window mode if any UI flag was set.
if showSettings || showNetworks || showDebug { if showSettings || showNetworks || showDebug || showLoginURL {
a.Run() a.Run()
return return
} }
@ -104,7 +107,7 @@ func main() {
} }
// parseFlags reads and returns all needed command-line flags. // parseFlags reads and returns all needed command-line flags.
func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) { func parseFlags() (daemonAddr string, showSettings, showNetworks, showLoginURL, showDebug bool, errorMsg string, saveLogsInFile bool) {
defaultDaemonAddr := "unix:///var/run/netbird.sock" defaultDaemonAddr := "unix:///var/run/netbird.sock"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731" defaultDaemonAddr = "tcp://127.0.0.1:41731"
@ -112,6 +115,7 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool
flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showSettings, "settings", false, "run settings window")
flag.BoolVar(&showNetworks, "networks", false, "run networks window") flag.BoolVar(&showNetworks, "networks", false, "run networks window")
flag.BoolVar(&showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&showDebug, "debug", false, "run debug window") flag.BoolVar(&showDebug, "debug", false, "run debug window")
flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
@ -253,6 +257,7 @@ type serviceClient struct {
exitNodeStates []exitNodeState exitNodeStates []exitNodeState
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window
} }
type menuHandler struct { type menuHandler struct {
@ -263,7 +268,7 @@ type menuHandler struct {
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor also builds the UI elements for the settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showLoginURL bool, showDebug bool) *serviceClient {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &serviceClient{ s := &serviceClient{
ctx: ctx, ctx: ctx,
@ -286,6 +291,8 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
s.showSettingsUI() s.showSettingsUI()
case showNetworks: case showNetworks:
s.showNetworksUI() s.showNetworksUI()
case showLoginURL:
s.showLoginURL()
case showDebug: case showDebug:
s.showDebugUI() s.showDebugUI()
} }
@ -445,11 +452,11 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
} }
} }
func (s *serviceClient) login() error { func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf("get client: %v", err)
return err return nil, err
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
@ -457,24 +464,24 @@ func (s *serviceClient) login() error {
}) })
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) log.Errorf("login to management URL with: %v", err)
return err return nil, err
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin && openURL {
err = open.Run(loginResp.VerificationURIComplete) err = open.Run(loginResp.VerificationURIComplete)
if err != nil { if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err) log.Errorf("opening the verification uri in the browser failed: %v", err)
return err return nil, err
} }
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
log.Errorf("waiting sso login failed with: %v", err) log.Errorf("waiting sso login failed with: %v", err)
return err return nil, err
} }
} }
return nil return loginResp, nil
} }
func (s *serviceClient) menuUpClick() error { func (s *serviceClient) menuUpClick() error {
@ -486,7 +493,7 @@ func (s *serviceClient) menuUpClick() error {
return err return err
} }
err = s.login() _, err = s.login(true)
if err != nil { if err != nil {
log.Errorf("login failed with: %v", err) log.Errorf("login failed with: %v", err)
return err return err
@ -558,7 +565,7 @@ func (s *serviceClient) updateStatus() error {
defer s.updateIndicationLock.Unlock() defer s.updateIndicationLock.Unlock()
// notify the user when the session has expired // notify the user when the session has expired
if status.Status == string(internal.StatusNeedsLogin) { if status.Status == string(internal.StatusSessionExpired) {
s.onSessionExpire() s.onSessionExpire()
} }
@ -732,7 +739,6 @@ func (s *serviceClient) onTrayReady() {
go s.eventHandler.listen(s.ctx) go s.eventHandler.listen(s.ctx)
} }
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
if s.logFile == "" { if s.logFile == "" {
// attach child's streams to parent's streams // attach child's streams to parent's streams
@ -871,17 +877,9 @@ func (s *serviceClient) onUpdateAvailable() {
// onSessionExpire sends a notification to the user when the session expires. // onSessionExpire sends a notification to the user when the session expires.
func (s *serviceClient) onSessionExpire() { func (s *serviceClient) onSessionExpire() {
s.sendNotification = true
if s.sendNotification { if s.sendNotification {
title := "Connection session expired" s.eventHandler.runSelfCommand("login-url", "true")
if runtime.GOOS == "darwin" {
title = "NetBird connection session expired"
}
s.app.SendNotification(
fyne.NewNotification(
title,
"Please re-authenticate to connect to the network",
),
)
s.sendNotification = false s.sendNotification = false
} }
} }
@ -955,9 +953,9 @@ func (s *serviceClient) updateConfig() error {
ServerSSHAllowed: &sshAllowed, ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled, RosenpassEnabled: &rosenpassEnabled,
DisableAutoConnect: &disableAutoStart, DisableAutoConnect: &disableAutoStart,
DisableNotifications: &notificationsDisabled,
LazyConnectionEnabled: &lazyConnectionEnabled, LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound, BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
} }
if err := s.restartClient(&loginRequest); err != nil { if err := s.restartClient(&loginRequest); err != nil {
@ -991,6 +989,87 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
return nil return nil
} }
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() {
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
if s.wLoginURL == nil {
s.wLoginURL = s.app.NewWindow("NetBird Session Expired")
s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon)
}
// add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
btn := widget.NewButtonWithIcon("Re-authenticate", theme.ViewRefreshIcon(), func() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
if err := openURL(verificationURL); err != nil {
log.Errorf("failed to open login URL: %v", err)
return
}
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil {
log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
return
}
label.SetText("Re-authentication successful.\nReconnecting")
time.Sleep(300 * time.Millisecond)
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err)
return
}
label.SetText("Connection successful.\nClosing this window.")
time.Sleep(time.Second)
s.wLoginURL.Close()
})
img := canvas.NewImageFromResource(resIcon)
img.FillMode = canvas.ImageFillContain
img.SetMinSize(fyne.NewSize(64, 64))
img.Resize(fyne.NewSize(64, 64))
// center the content vertically
content := container.NewVBox(
layout.NewSpacer(),
img,
label,
btn,
layout.NewSpacer(),
)
s.wLoginURL.SetContent(container.NewCenter(content))
s.wLoginURL.Show()
}
func openURL(url string) error { func openURL(url string) error {
var err error var err error
switch runtime.GOOS { switch runtime.GOOS {

View File

@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
) )
type eventHandler struct { type eventHandler struct {
@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() {
} }
func (h *eventHandler) handleUpdateClick() { func (h *eventHandler) handleUpdateClick() {
if err := openURL("https://netbird.io/download"); err != nil { if err := openURL(version.DownloadUrl()); err != nil {
log.Errorf("failed to open download URL: %v", err) log.Errorf("failed to open download URL: %v", err)
} }
} }

View File

@ -358,8 +358,6 @@ func (s *serviceClient) updateExitNodes() {
} else { } else {
s.mExitNode.Disable() s.mExitNode.Disable()
} }
log.Debugf("Exit nodes updated: %d", len(s.mExitNodeItems))
} }
func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {

2
go.mod
View File

@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c h1:SdZxYjR9XXHLyRsTbS1EHBr6+RI15oie1K9Q8yvi3FY= github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@ -159,6 +159,12 @@ var (
if err != nil { if err != nil {
return err return err
} }
integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics)
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false) store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil { if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
@ -176,7 +182,7 @@ var (
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey) eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize database: %s", err) return fmt.Errorf("failed to initialize database: %s", err)
} }

View File

@ -24,6 +24,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
@ -409,14 +410,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }
@ -454,6 +456,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
//nolint
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -478,8 +484,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) if am.peerLoginExpiry.IsSchedulerRunning(accountID) {
log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID)
return
}
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }

View File

@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(1)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },

View File

@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)

View File

@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }

View File

@ -4,19 +4,19 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"unicode/utf8" "unicode/utf8"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route // we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped // when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID { if checkRoute.ID == prefixRoute.ID {
continue continue
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, ok := peerGroupsMap[groupID]
if group == nil { if !ok || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
} }
if peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range checkRoute.PeerGroups {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, ok := peersMap[id]
if peer == nil { if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
if len(domains) == 0 && !prefix.IsValid() { var newRoute *route.Route
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") var updateAccountPeers bool
}
if len(domains) > 0 { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
prefix = getPlaceholderIP() newRoute = &route.Route{
} ID: route.ID(xid.New().String()),
AccountID: accountID,
if peerID != "" && len(peerGroupIDs) != 0 { Network: prefix,
return nil, status.Errorf( Domains: domains,
status.InvalidArgument, KeepRoute: keepRoute,
"peer with ID %s and peers group %s should not be provided at the same time", NetID: netID,
peerID, peerGroupIDs) Description: description,
} Peer: peerID,
PeerGroups: peerGroupIDs,
var newRoute route.Route NetworkType: networkType,
newRoute.ID = route.ID(xid.New().String()) Masquerade: masquerade,
Metric: metric,
if len(peerGroupIDs) > 0 { Enabled: enabled,
err = validateGroups(peerGroupIDs, account.Groups) Groups: groups,
if err != nil { AccessControlGroups: accessControlGroupIDs,
return nil, err
} }
}
if len(accessControlGroupIDs) > 0 { if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
err = validateGroups(accessControlGroupIDs, account.Groups) return err
if err != nil {
return nil, err
} }
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
if err != nil {
return nil, err
}
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) return groupsMap, nil
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if am.isRouteChangeAffectPeers(account, routy) {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }

View File

@ -12,6 +12,7 @@ import (
type Scheduler interface { type Scheduler interface {
Cancel(ctx context.Context, IDs []string) Cancel(ctx context.Context, IDs []string)
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
IsSchedulerRunning(ID string) bool
} }
// MockScheduler is a mock implementation of Scheduler // MockScheduler is a mock implementation of Scheduler
@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
mock.CancelFunc(ctx, IDs) mock.CancelFunc(ctx, IDs)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
} }
// Schedule mocks the Schedule function of the Scheduler interface // Schedule mocks the Schedule function of the Scheduler interface
@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
mock.ScheduleFunc(ctx, in, ID, job) mock.ScheduleFunc(ctx, in, ID, job)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined")
}
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
// MockScheduler does not implement IsSchedulerRunning, so we return false
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
return false
} }
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s
}() }()
} }
// IsSchedulerRunning checks if a job with the provided ID is scheduled to run
func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool {
wm.mu.Lock()
defer wm.mu.Unlock()
_, ok := wm.jobs[ID]
return ok
}

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -182,7 +181,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedCreatedAt, tCase.expectedExpiresAt, key.Id,
tCase.expectedUpdatedAt, tCase.expectedGroups, false) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
@ -258,10 +257,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string var expectedAutoGroups []string
key, plainKey := types.GenerateDefaultSetupKey() key, _ := types.GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true)
} }
@ -275,10 +274,10 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUpdatedAt := time.Now().UTC() expectedUpdatedAt := time.Now().UTC()
var expectedAutoGroups []string var expectedAutoGroups []string
key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false) key, _ := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true)
} }

View File

@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error {
func NewOperationNotFoundError(operation operations.Operation) error { func NewOperationNotFoundError(operation operations.Operation) error {
return Errorf(NotFound, "operation: %s not found", operation) return Errorf(NotFound, "operation: %s not found", operation)
} }
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@ -23,8 +23,6 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@ -34,6 +32,7 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db, lockStrength, accountID) var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.NewRouteNotFoundError(routeID)
}
return nil
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil return nil
} }
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record []T
result := tx.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record T
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store. // SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).

View File

@ -19,21 +19,17 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
) )
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups)) require.Equal(t, 8003, len(accountGroups))
} }
func TestSqlStore_GetAccountRoutes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve routes by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, routes, tt.expectedCount)
})
}
}
func TestSqlStore_GetRouteByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
routeID string
expectError bool
}{
{
name: "retrieve existing route",
routeID: "ct03t427qv97vmtmglog",
expectError: false,
},
{
name: "retrieve non-existing route",
routeID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty route ID",
routeID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, route)
} else {
require.NoError(t, err)
require.NotNil(t, route)
require.Equal(t, tt.routeID, string(route.ID))
}
})
}
}
func TestSqlStore_SaveRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
AccessControlGroups: []string{},
}
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
require.NoError(t, err)
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
require.NoError(t, err)
require.Equal(t, route, saveRoute)
}
func TestSqlStore_DeleteRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
routeID := "ct03t427qv97vmtmglog"
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
require.NoError(t, err)
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
require.Error(t, err)
require.Nil(t, route)
}
func TestSqlStore_GetAccountMeta(t *testing.T) { func TestSqlStore_GetAccountMeta(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@ -145,7 +145,9 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)

View File

@ -184,10 +184,10 @@ func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpo
} }
appMetrics.listener = listener appMetrics.listener = listener
go func() { go func() {
err := http.Serve(listener, rootRouter) if err := http.Serve(listener, rootRouter); err != nil && err != http.ErrServerClosed {
if err != nil { log.WithContext(ctx).Errorf("metrics server error: %v", err)
return
} }
log.WithContext(ctx).Info("metrics server stopped")
}() }()
log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String())
@ -204,7 +204,7 @@ func (appMetrics *defaultAppMetrics) GetMeter() metric2.Meter {
func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
exporter, err := prometheus.New() exporter, err := prometheus.New()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create prometheus exporter: %w", err)
} }
provider := metric.NewMeterProvider(metric.WithReader(exporter)) provider := metric.NewMeterProvider(metric.WithReader(exporter))
@ -213,32 +213,32 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
idpMetrics, err := NewIDPMetrics(ctx, meter) idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
} }
middleware, err := NewMetricsMiddleware(ctx, meter) middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
} }
grpcMetrics, err := NewGRPCMetrics(ctx, meter) grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
} }
storeMetrics, err := NewStoreMetrics(ctx, meter) storeMetrics, err := NewStoreMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
} }
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter) updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
} }
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter) accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
} }
return &defaultAppMetrics{ return &defaultAppMetrics{

View File

@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@ -3,13 +3,12 @@ package types
import ( import (
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"hash/fnv"
"strconv"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
) )
@ -170,7 +169,7 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{ return &SetupKey{
Id: strconv.Itoa(int(Hash(key))), Id: xid.New().String(),
Key: encodedHashedKey, Key: encodedHashedKey,
KeySecret: HiddenKey(key, 4), KeySecret: HiddenKey(key, 4),
Name: name, Name: name,
@ -192,12 +191,3 @@ func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false, false) SetupKeyUnlimitedUsage, false, false)
} }
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

35
signal/cmd/env.go Normal file
View File

@ -0,0 +1,35 @@
package cmd
import (
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
func setFlagsFromEnvVars(cmd *cobra.Command) {
flags := cmd.PersistentFlags()
flags.VisitAll(func(f *pflag.Flag) {
newEnvVar := flagNameToEnvVar(f.Name, "NB_")
value, present := os.LookupEnv(newEnvVar)
if !present {
return
}
err := flags.Set(f.Name, value)
if err != nil {
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
}
})
}
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
func flagNameToEnvVar(cmdFlag string, prefix string) string {
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
upper := strings.ToUpper(parsed)
return prefix + upper
}

View File

@ -303,4 +303,5 @@ func init() {
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
setFlagsFromEnvVars(runCmd)
} }