diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index bc153479c..097daa9e2 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -5,11 +5,14 @@ import ( "errors" "net" "net/netip" + "strings" + "sync" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" ) @@ -17,23 +20,27 @@ const errResolveFailed = "failed to resolve query for domain=%s: %v" const upstreamTimeout = 15 * time.Second type DNSForwarder struct { - listenAddress string - ttl uint32 - domains []string + listenAddress string + ttl uint32 + domains []string + statusRecorder *peer.Status dnsServer *dns.Server mux *dns.ServeMux + + resId sync.Map } -func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ - listenAddress: listenAddress, - ttl: ttl, + listenAddress: listenAddress, + ttl: ttl, + statusRecorder: statusRecorder, } } -func (f *DNSForwarder) Listen(domains []string) error { +func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error { log.Infof("listen DNS forwarder on address=%s", f.listenAddress) mux := dns.NewServeMux() @@ -45,22 +52,31 @@ func (f *DNSForwarder) Listen(domains []string) error { f.dnsServer = dnsServer f.mux = mux - f.UpdateDomains(domains) + f.UpdateDomains(domains, resIds) return dnsServer.ListenAndServe() } -func (f *DNSForwarder) UpdateDomains(domains []string) { +func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) { log.Debugf("Updating domains from %v to %v", f.domains, domains) for _, d := range f.domains { f.mux.HandleRemove(d) + f.statusRecorder.RemoveResolvedIPLookupEntry(d) } + f.resId.Clear() newDomains := filterDomains(domains) for _, d := range newDomains { f.mux.HandleFunc(d, f.handleDNSQuery) } + + for domain, resId := range resIds { + if domain != "" { + f.resId.Store(domain, resId) + } + } + f.domains = newDomains } @@ -106,6 +122,21 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { return } + resId, ok := f.resId.Load(strings.TrimSuffix(domain, ".")) + if ok { + for _, ip := range ips { + var ipWithSuffix string + if ip.Is4() { + ipWithSuffix = ip.String() + "/32" + log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix) + } else { + ipWithSuffix = ip.String() + "/128" + log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix) + } + f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string)) + } + } + f.addIPsToResponse(resp, domain, ips) if err := w.WriteMsg(resp); err != nil { diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 8dae06aec..a51ae7abb 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -10,6 +10,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/peer" ) const ( @@ -19,19 +20,21 @@ const ( ) type Manager struct { - firewall firewall.Manager + firewall firewall.Manager + statusRecorder *peer.Status fwRules []firewall.Rule dnsForwarder *DNSForwarder } -func NewManager(fw firewall.Manager) *Manager { +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ - firewall: fw, + firewall: fw, + statusRecorder: statusRecorder, } } -func (m *Manager) Start(domains []string) error { +func (m *Manager) Start(domains []string, resIds map[string]string) error { log.Infof("starting DNS forwarder") if m.dnsForwarder != nil { return nil @@ -41,9 +44,9 @@ func (m *Manager) Start(domains []string) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL) + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder) go func() { - if err := m.dnsForwarder.Listen(domains); err != nil { + if err := m.dnsForwarder.Listen(domains, resIds); err != nil { // todo handle close error if it is exists log.Errorf("failed to start DNS forwarder, err: %v", err) } @@ -52,12 +55,12 @@ func (m *Manager) Start(domains []string) error { return nil } -func (m *Manager) UpdateDomains(domains []string) { +func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) { if m.dnsForwarder == nil { return } - m.dnsForwarder.UpdateDomains(domains) + m.dnsForwarder.UpdateDomains(domains, resIds) } func (m *Manager) Stop(ctx context.Context) error { diff --git a/client/internal/engine.go b/client/internal/engine.go index 260e807a0..74a07927c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -962,8 +962,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) - dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) - e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains) + dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) + e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds) routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { @@ -1079,21 +1079,29 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { return routes } -func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { +func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } var dnsRoutes []string + resIds := make(map[string]string) for _, protoRoute := range protoRoutes { if len(protoRoute.Domains) == 0 { continue } if protoRoute.Peer == myPubKey { dnsRoutes = append(dnsRoutes, protoRoute.Domains...) + // resource ID is the first part of the ID + resId := strings.Split(protoRoute.ID, ":") + for _, domain := range protoRoute.Domains { + if len(resId) > 0 { + resIds[domain] = resId[0] + } + } } } - return dnsRoutes + return dnsRoutes, resIds } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { @@ -1760,7 +1768,7 @@ func (e *Engine) GetWgAddr() net.IP { } // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag -func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { +func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) { if !enabled { if e.dnsForwardMgr == nil { return @@ -1774,15 +1782,15 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { if len(domains) > 0 { log.Infof("enable domain router service for domains: %v", domains) if e.dnsForwardMgr == nil { - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(domains); err != nil { + if err := e.dnsForwardMgr.Start(domains, resIds); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } } else { log.Infof("update domain router service for domains: %v", domains) - e.dnsForwardMgr.UpdateDomains(domains) + e.dnsForwardMgr.UpdateDomains(domains, resIds) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 43dc975fd..a3bd091b6 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -86,18 +86,18 @@ func (l *Logger) startReceiver() { Timestamp: time.Now().UTC(), } - var isExitNode bool - if event.Direction == types.Ingress { - if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { - event.SourceResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) - } - } else if event.Direction == types.Egress { - if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { - event.DestResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.DestIP) - } + var isSrcExitNode bool + var isDestExitNode bool + + if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { + event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) } - if l.shouldStore(eventFields, isExitNode) { + if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { + event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) + } + + if l.shouldStore(eventFields, isSrcExitNode || isDestExitNode) { l.Store.StoreEvent(&event) } } diff --git a/client/internal/peer/route.go b/client/internal/peer/route.go index ff9aafcb2..c3567dcc9 100644 --- a/client/internal/peer/route.go +++ b/client/internal/peer/route.go @@ -2,37 +2,89 @@ package peer import ( "net/netip" + "sort" "sync" log "github.com/sirupsen/logrus" ) +// routeEntry holds the route prefix and the corresponding resource ID. +type routeEntry struct { + prefix netip.Prefix + resourceID string +} + type routeIDLookup struct { - localMap sync.Map - remoteMap sync.Map + localRoutes []routeEntry + localLock sync.RWMutex + + remoteRoutes []routeEntry + remoteLock sync.RWMutex + resolvedIPs sync.Map } func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) { - _, exists := r.localMap.LoadOrStore(route, resourceID) - if exists { - log.Tracef("resourceID %s already exists in local map", resourceID) + r.localLock.Lock() + defer r.localLock.Unlock() + + // update the resource id if the route already exists. + for i, entry := range r.localRoutes { + if entry.prefix == route { + r.localRoutes[i].resourceID = resourceID + log.Tracef("resourceID for route %v updated to %s in local routes", route, resourceID) + return + } } + + // append and sort descending by prefix bits (more specific first) + r.localRoutes = append(r.localRoutes, routeEntry{prefix: route, resourceID: resourceID}) + sort.Slice(r.localRoutes, func(i, j int) bool { + return r.localRoutes[i].prefix.Bits() > r.localRoutes[j].prefix.Bits() + }) } func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) { - r.localMap.Delete(route) -} + r.localLock.Lock() + defer r.localLock.Unlock() -func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { - _, exists := r.remoteMap.LoadOrStore(route, resourceID) - if exists { - log.Tracef("resourceID %s already exists in remote map", resourceID) + for i, entry := range r.localRoutes { + if entry.prefix == route { + r.localRoutes = append(r.localRoutes[:i], r.localRoutes[i+1:]...) + return + } } } +func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { + r.remoteLock.Lock() + defer r.remoteLock.Unlock() + + for i, entry := range r.remoteRoutes { + if entry.prefix == route { + r.remoteRoutes[i].resourceID = resourceID + log.Tracef("resourceID for route %v updated to %s in remote routes", route, resourceID) + return + } + } + + // append and sort descending by prefix bits. + r.remoteRoutes = append(r.remoteRoutes, routeEntry{prefix: route, resourceID: resourceID}) + sort.Slice(r.remoteRoutes, func(i, j int) bool { + return r.remoteRoutes[i].prefix.Bits() > r.remoteRoutes[j].prefix.Bits() + }) +} + func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) { - r.remoteMap.Delete(route) + r.remoteLock.Lock() + defer r.remoteLock.Unlock() + + for i, entry := range r.remoteRoutes { + if entry.prefix == route { + r.remoteRoutes = append(r.remoteRoutes[:i], r.remoteRoutes[i+1:]...) + return + } + } } func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) { @@ -44,37 +96,35 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) { } // Lookup returns the resource ID for the given IP address -// and a bool indicating if the IP is an exit node +// and a bool indicating if the IP is an exit node. func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { - var isExitNode bool - - resId, ok := r.resolvedIPs.Load(ip) - if ok { - return resId.(string), false + if res, ok := r.resolvedIPs.Load(ip); ok { + return res.(string), false } var resourceID string - r.localMap.Range(func(key, value interface{}) bool { - pref := key.(netip.Prefix) - if pref.Contains(ip) { - resourceID = value.(string) - isExitNode = pref.Bits() == 0 - return false + var isExitNode bool + r.localLock.RLock() + for _, entry := range r.localRoutes { + if entry.prefix.Contains(ip) { + resourceID = entry.resourceID + isExitNode = (entry.prefix.Bits() == 0) + break } - return true - }) + } + r.localLock.RUnlock() if resourceID == "" { - r.remoteMap.Range(func(key, value interface{}) bool { - pref := key.(netip.Prefix) - if pref.Contains(ip) { - resourceID = value.(string) - isExitNode = pref.Bits() == 0 - return false + r.remoteLock.RLock() + for _, entry := range r.remoteRoutes { + if entry.prefix.Contains(ip) { + resourceID = entry.resourceID + isExitNode = (entry.prefix.Bits() == 0) + break } - return true - }) + } + r.remoteLock.RUnlock() } return resourceID, isExitNode diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index dc96118e3..9b3fc744d 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -610,6 +610,28 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) { delete(d.localPeer.Routes, route) } +// AddResolvedIPLookupEntry adds a resolved IP lookup entry +func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) { + d.mux.Lock() + defer d.mux.Unlock() + + pref, err := netip.ParsePrefix(route) + if err == nil { + d.routeIDLookup.AddResolvedIP(resourceId, pref) + } +} + +// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry +func (d *Status) RemoveResolvedIPLookupEntry(route string) { + d.mux.Lock() + defer d.mux.Unlock() + + pref, err := netip.ParsePrefix(route) + if err == nil { + d.routeIDLookup.RemoveResolvedIP(pref) + } +} + // CleanLocalPeerStateRoutes cleans all routes from the local peer state func (d *Status) CleanLocalPeerStateRoutes() { d.mux.Lock()