Fix DNS resolution for routes on iOS (#2378)

This commit is contained in:
pascal-fischer 2024-08-02 18:43:00 +02:00 committed by GitHub
parent 727a4f0753
commit 501fd93e47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 124 additions and 37 deletions

View File

@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(wgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }

View File

@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
service: newServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },

View File

@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type serviceViaMemory struct { type ServiceViaMemory struct {
wgInterface WGIface wgInterface WGIface
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
runtimeIP string runtimeIP string
@ -22,8 +22,8 @@ type serviceViaMemory struct {
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
} }
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
s := &serviceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
return s return s
} }
func (s *serviceViaMemory) Listen() error { func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
return nil return nil
} }
func (s *serviceViaMemory) Stop() { func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock() s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
s.listenerIsRunning = false s.listenerIsRunning = false
} }
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) { func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }
func (s *serviceViaMemory) DeregisterMux(pattern string) { func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern) s.dnsMux.HandleRemove(pattern)
} }
func (s *serviceViaMemory) RuntimePort() int { func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort return s.runtimePort
} }
func (s *serviceViaMemory) RuntimeIP() string { func (s *ServiceViaMemory) RuntimeIP() string {
return s.runtimeIP return s.runtimeIP
} }
func (s *serviceViaMemory) filterDNSTraffic() (string, error) { func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter() filter := s.wgInterface.GetFilter()
if filter == nil { if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized") return "", fmt.Errorf("can't set DNS filter, filter not initialized")

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"context" "context"
"fmt"
"net" "net"
"syscall" "syscall"
"time" "time"
@ -17,9 +18,9 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP net.IP
lNet *net.IPNet lNet *net.IPNet
iIndex int interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
@ -32,17 +33,11 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
lIP: ip, lIP: ip,
lNet: net, lNet: net,
iIndex: index, interfaceName: interfaceName,
} }
ios.upstreamClient = ios ios.upstreamClient = ios
@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{} client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
log.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
} }
timeout := upstreamTimeout timeout := upstreamTimeout
@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost) upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate(timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
} }
// Cannot use client.ExchangeContext because it overwrites our Dialer // Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream) return client.Exchange(r, upstream)
} }
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client { func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: u.lIP, IP: ip,
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
} }
if err := c.Control(fn); err != nil { if err := c.Control(fn); err != nil {
@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
} }
return client return client, nil
} }
func getInterfaceIndex(interfaceName string) (int, error) { func getInterfaceIndex(interfaceName string) (int, error) {

View File

@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
} }
return client return client
} }
@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -47,6 +48,8 @@ type Route struct {
currentPeerKey string currentPeerKey string
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface
resolverAddr string
} }
func NewRoute( func NewRoute(
@ -55,6 +58,8 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration, interval time.Duration,
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface *iface.WGIface,
resolverAddr string,
) *Route { ) *Route {
return &Route{ return &Route{
route: rt, route: rt,
@ -63,6 +68,8 @@ func NewRoute(
interval: interval, interval: interval,
dynamicDomains: domainMap{}, dynamicDomains: domainMap{},
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
} }
} }
@ -228,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1) wg.Add(1)
go func(domain domain.Domain) { go func(domain domain.Domain) {
defer wg.Done() defer wg.Done()
ips, err := net.LookupIP(string(domain))
ips, err := r.getIPsFromResolver(domain)
if err != nil { if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
return ips, err = net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
} }
for _, ip := range ips { for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {

View File

@ -0,0 +1,13 @@
//go:build !ios
package dynamic
import (
"net"
"github.com/netbirdio/netbird/management/domain"
)
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain))
}

View File

@ -0,0 +1,55 @@
//go:build ios
package dynamic
import (
"fmt"
"net"
"time"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/management/domain"
)
const dialTimeout = 10 * time.Second
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
startTime := time.Now()
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}
if response.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
}
ips := make([]net.IP, 0)
for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
}
return ips, nil
}