trying to bind the DNS resolver dialer to an interface

This commit is contained in:
Pascal Fischer 2023-11-03 14:26:07 +01:00
parent 79f60b86c4
commit 64084ca130
6 changed files with 109 additions and 15 deletions

View File

@ -42,9 +42,10 @@ func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.S
return runClient(ctx, config, statusRecorder, mobileDependency) return runClient(ctx, config, statusRecorder, mobileDependency)
} }
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, routeListener routemanager.RouteListener, dnsManager dns.IosDnsManager) error { func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, routeListener routemanager.RouteListener, dnsManager dns.IosDnsManager, interfaceName string) error {
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
InterfaceName: interfaceName,
RouteListener: routeListener, RouteListener: routeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
} }

View File

@ -53,6 +53,9 @@ type DefaultServer struct {
permanent bool permanent bool
hostsDnsList []string hostsDnsList []string
hostsDnsListLock sync.Mutex hostsDnsListLock sync.Mutex
interfaceName string
wgAddr string
} }
type handlerWithStop interface { type handlerWithStop interface {
@ -66,7 +69,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) { func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, interfaceName string, wgAddr string) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@ -83,13 +86,13 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService), nil return newDefaultServer(ctx, wgInterface, dnsService, interfaceName, wgAddr), nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer { func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface)) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), "", "")
ds.permanent = true ds.permanent = true
ds.hostsDnsList = hostsDnsList ds.hostsDnsList = hostsDnsList
ds.addHostRootZone() ds.addHostRootZone()
@ -97,7 +100,7 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
return ds return ds
} }
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer { func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, interfaceName string, wgAddr string) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
@ -108,6 +111,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
interfaceName: interfaceName,
wgAddr: wgAddr,
} }
return defaultServer return defaultServer
@ -295,7 +300,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue continue
} }
handler := newUpstreamResolver(s.ctx) handler := newUpstreamResolver(s.ctx, s.interfaceName, s.wgAddr)
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
@ -468,7 +473,7 @@ func (s *DefaultServer) upstreamCallbacks(
} }
func (s *DefaultServer) addHostRootZone() { func (s *DefaultServer) addHostRootZone() {
handler := newUpstreamResolver(s.ctx) handler := newUpstreamResolver(s.ctx, s.interfaceName, s.wgAddr)
handler.upstreamServers = make([]string, len(s.hostsDnsList)) handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList { for n, ua := range s.hostsDnsList {
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua) handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)

View File

@ -4,14 +4,19 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/libp2p/go-netroute"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
) )
const ( const (
@ -40,12 +45,80 @@ type upstreamResolver struct {
reactivate func() reactivate func()
} }
func newUpstreamResolver(parentCTX context.Context) *upstreamResolver { // func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
// ctx, cancel := context.WithCancel(parentCTX)
// return &upstreamResolver{
// ctx: ctx,
// cancel: cancel,
// upstreamClient: &dns.Client{},
// upstreamTimeout: upstreamTimeout,
// reactivatePeriod: reactivatePeriod,
// failsTillDeact: failsTillDeact,
// }
// }
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return 0, err
}
return iface.Index, nil
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr string) *upstreamResolver {
ctx, cancel := context.WithCancel(parentCTX) ctx, cancel := context.WithCancel(parentCTX)
// Specify the local IP address you want to bind to
localIP, _, err := net.ParseCIDR(wgAddr) // Should be our interface IP
if err != nil {
log.Errorf("error while parsing CIDR: %s", err)
}
index, err := getInterfaceIndex(interfaceName)
rand.Seed(time.Now().UnixNano())
port := rand.Intn(4001) + 1000
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
}
localIFaceIndex := index // Should be our interface index
// Create a custom dialer with the LocalAddr set to the desired IP
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: localIP,
Port: port, // Let the OS pick a free port
},
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, localIFaceIndex)
}
if err := c.Control(fn); err != nil {
return err
}
return operr
},
}
// pktConn, err := dialer.Dial("udp", "100.127.136.151:10053")
// if err != nil {
// log.Errorf("error while dialing: %s", err)
//
// } else {
// pktConn.Write([]byte("hello"))
// pktConn.Close()
// }
// Create a new DNS client with the custom dialer
client := &dns.Client{
Dialer: dialer,
}
return &upstreamResolver{ return &upstreamResolver{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
upstreamClient: &dns.Client{}, upstreamClient: client,
upstreamTimeout: upstreamTimeout, upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
@ -61,7 +134,7 @@ func (u *upstreamResolver) stop() {
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails() defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.WithField("question", r.Question[0]).Debug("received an upstream question")
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
@ -70,6 +143,19 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
log.Debugf("querying the upstream %s", upstream)
rr, errR := netroute.New()
if errR != nil {
log.Errorf("unable to create networute: %s", errR)
} else {
add := netip.MustParseAddrPort(upstream)
_, gateway, preferredSrc, errR := rr.Route(add.Addr().AsSlice())
if errR != nil {
log.Errorf("getting routes returned an error: %v", errR)
} else {
log.Infof("upstream %s gateway: %s, preferredSrc: %s", add.Addr(), gateway, preferredSrc)
}
}
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
@ -87,7 +173,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
log.Tracef("took %s to query the upstream %s", t, upstream) log.Debugf("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm) err = w.WriteMsg(rm)
if err != nil { if err != nil {

View File

@ -206,7 +206,7 @@ func (e *Engine) Start() error {
} else { } else {
// todo fix custom address // todo fix custom address
if e.dnsServer == nil { if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.mobileDep.InterfaceName, wgAddr)
if err != nil { if err != nil {
e.close() e.close()
return err return err

View File

@ -16,4 +16,5 @@ type MobileDependency struct {
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
InterfaceName string
} }

View File

@ -67,8 +67,9 @@ func NewClient(cfgFile, deviceName string, osVersion string, osName string, rout
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32) error { func (c *Client) Run(fd int32, interfaceName string) error {
log.Infof("Starting NetBird client") log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@ -97,7 +98,7 @@ func (c *Client) Run(fd int32) error {
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {} c.onHostDnsFn = func([]string) {}
return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.routeListener, c.dnsManager) return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.routeListener, c.dnsManager, interfaceName)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources