mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 17:43:38 +01:00
Fix DNS resolution for routes on iOS (#2378)
This commit is contained in:
parent
727a4f0753
commit
501fd93e47
@ -94,7 +94,7 @@ func NewDefaultServer(
|
|||||||
|
|
||||||
var dnsService service
|
var dnsService service
|
||||||
if wgInterface.IsUserspaceBind() {
|
if wgInterface.IsUserspaceBind() {
|
||||||
dnsService = newServiceViaMemory(wgInterface)
|
dnsService = NewServiceViaMemory(wgInterface)
|
||||||
} else {
|
} else {
|
||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@ -130,7 +130,7 @@ func NewDefaultServerIos(
|
|||||||
iosDnsManager IosDnsManager,
|
iosDnsManager IosDnsManager,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
service: newServiceViaMemory(&mocWGIface{}),
|
service: NewServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type serviceViaMemory struct {
|
type ServiceViaMemory struct {
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
runtimeIP string
|
runtimeIP string
|
||||||
@ -22,8 +22,8 @@ type serviceViaMemory struct {
|
|||||||
listenerFlagLock sync.Mutex
|
listenerFlagLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||||
s := &serviceViaMemory{
|
s := &ServiceViaMemory{
|
||||||
wgInterface: wgIface,
|
wgInterface: wgIface,
|
||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) Listen() error {
|
func (s *ServiceViaMemory) Listen() error {
|
||||||
s.listenerFlagLock.Lock()
|
s.listenerFlagLock.Lock()
|
||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) Stop() {
|
func (s *ServiceViaMemory) Stop() {
|
||||||
s.listenerFlagLock.Lock()
|
s.listenerFlagLock.Lock()
|
||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
|
|||||||
s.listenerIsRunning = false
|
s.listenerIsRunning = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
|
||||||
s.dnsMux.HandleRemove(pattern)
|
s.dnsMux.HandleRemove(pattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) RuntimePort() int {
|
func (s *ServiceViaMemory) RuntimePort() int {
|
||||||
return s.runtimePort
|
return s.runtimePort
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) RuntimeIP() string {
|
func (s *ServiceViaMemory) RuntimeIP() string {
|
||||||
return s.runtimeIP
|
return s.runtimeIP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||||
filter := s.wgInterface.GetFilter()
|
filter := s.wgInterface.GetFilter()
|
||||||
if filter == nil {
|
if filter == nil {
|
||||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||||
|
@ -4,6 +4,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -17,9 +18,9 @@ import (
|
|||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
lIP net.IP
|
lIP net.IP
|
||||||
lNet *net.IPNet
|
lNet *net.IPNet
|
||||||
iIndex int
|
interfaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
@ -32,17 +33,11 @@ func newUpstreamResolver(
|
|||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
|
|
||||||
index, err := getInterfaceIndex(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
lIP: ip,
|
lIP: ip,
|
||||||
lNet: net,
|
lNet: net,
|
||||||
iIndex: index,
|
interfaceName: interfaceName,
|
||||||
}
|
}
|
||||||
ios.upstreamClient = ios
|
ios.upstreamClient = ios
|
||||||
|
|
||||||
@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
client := &dns.Client{}
|
client := &dns.Client{}
|
||||||
upstreamHost, _, err := net.SplitHostPort(upstream)
|
upstreamHost, _, err := net.SplitHostPort(upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while parsing upstream host: %s", err)
|
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := upstreamTimeout
|
timeout := upstreamTimeout
|
||||||
@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
upstreamIP := net.ParseIP(upstreamHost)
|
upstreamIP := net.ParseIP(upstreamHost)
|
||||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
client = u.getClientPrivate(timeout)
|
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||||
return client.Exchange(r, upstream)
|
return client.Exchange(r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
// This method is needed for iOS
|
// This method is needed for iOS
|
||||||
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
|
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
|
index, err := getInterfaceIndex(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: &net.UDPAddr{
|
LocalAddr: &net.UDPAddr{
|
||||||
IP: u.lIP,
|
IP: ip,
|
||||||
Port: 0, // Let the OS pick a free port
|
Port: 0, // Let the OS pick a free port
|
||||||
},
|
},
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var operr error
|
var operr error
|
||||||
fn := func(s uintptr) {
|
fn := func(s uintptr) {
|
||||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
|
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Control(fn); err != nil {
|
if err := c.Control(fn); err != nil {
|
||||||
@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
|
|||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
}
|
}
|
||||||
return client
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInterfaceIndex(interfaceName string) (int, error) {
|
func getInterfaceIndex(interfaceName string) (int, error) {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
|
||||||
if rt.IsDynamic() {
|
if rt.IsDynamic() {
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
|
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||||
}
|
}
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -47,6 +48,8 @@ type Route struct {
|
|||||||
currentPeerKey string
|
currentPeerKey string
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
resolverAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(
|
func NewRoute(
|
||||||
@ -55,6 +58,8 @@ func NewRoute(
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
interval time.Duration,
|
interval time.Duration,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
wgInterface *iface.WGIface,
|
||||||
|
resolverAddr string,
|
||||||
) *Route {
|
) *Route {
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: rt,
|
||||||
@ -63,6 +68,8 @@ func NewRoute(
|
|||||||
interval: interval,
|
interval: interval,
|
||||||
dynamicDomains: domainMap{},
|
dynamicDomains: domainMap{},
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
resolverAddr: resolverAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(domain domain.Domain) {
|
go func(domain domain.Domain) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
ips, err := net.LookupIP(string(domain))
|
|
||||||
|
ips, err := r.getIPsFromResolver(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||||
return
|
ips, err = net.LookupIP(string(domain))
|
||||||
|
if err != nil {
|
||||||
|
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dynamic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||||
|
return net.LookupIP(string(domain))
|
||||||
|
}
|
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package dynamic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
const dialTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||||
|
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error while creating private client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
|
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
|
||||||
|
}
|
||||||
|
|
||||||
|
ips := make([]net.IP, 0)
|
||||||
|
|
||||||
|
for _, answ := range response.Answer {
|
||||||
|
if aRecord, ok := answ.(*dns.A); ok {
|
||||||
|
ips = append(ips, aRecord.A)
|
||||||
|
}
|
||||||
|
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||||
|
ips = append(ips, aaaaRecord.AAAA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user