diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 69d1a58d6..d94bbe592 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -4,16 +4,14 @@ import ( "context" "errors" "fmt" - "math/rand" "net" - "net/netip" + "runtime" "sync" "sync/atomic" "syscall" "time" "github.com/cenkalti/backoff/v4" - "github.com/libp2p/go-netroute" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" @@ -40,6 +38,9 @@ type upstreamResolver struct { mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration + lIP net.IP + lName string + iIndex int deactivate func() reactivate func() @@ -60,6 +61,7 @@ type upstreamResolver struct { func getInterfaceIndex(interfaceName string) (int, error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { + log.Errorf("unable to get interface by name error: %s", err) return 0, err } @@ -75,54 +77,52 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr log.Errorf("error while parsing CIDR: %s", err) } index, err := getInterfaceIndex(interfaceName) - rand.Seed(time.Now().UnixNano()) - port := rand.Intn(4001) + 1000 - log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port) + log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s", interfaceName, index, localIP) if err != nil { log.Debugf("unable to get interface index for %s: %s", interfaceName, err) } localIFaceIndex := index // Should be our interface index - // Create a custom dialer with the LocalAddr set to the desired IP + + return &upstreamResolver{ + ctx: ctx, + cancel: cancel, + upstreamTimeout: upstreamTimeout, + reactivatePeriod: reactivatePeriod, + failsTillDeact: failsTillDeact, + lIP: localIP, + iIndex: localIFaceIndex, + lName: interfaceName, + } +} + +func (u *upstreamResolver) getClient() *dns.Client { dialer := &net.Dialer{ LocalAddr: &net.UDPAddr{ - IP: localIP, - Port: port, // Let the OS pick a free port + IP: u.lIP, + Port: 0, // Let the OS pick a free port }, + Timeout: upstreamTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, localIFaceIndex) + operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) } if err := c.Control(fn); err != nil { return err } + if operr != nil { + log.Errorf("error while setting socket option: %s", operr) + } + return operr }, } - // pktConn, err := dialer.Dial("udp", "100.127.136.151:10053") - // if err != nil { - // log.Errorf("error while dialing: %s", err) - // - // } else { - // pktConn.Write([]byte("hello")) - // pktConn.Close() - // } - - // Create a new DNS client with the custom dialer client := &dns.Client{ Dialer: dialer, } - - return &upstreamResolver{ - ctx: ctx, - cancel: cancel, - upstreamClient: client, - upstreamTimeout: upstreamTimeout, - reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, - } + return client } func (u *upstreamResolver) stop() { @@ -134,7 +134,7 @@ func (u *upstreamResolver) stop() { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { defer u.checkUpstreamFails() - log.WithField("question", r.Question[0]).Debug("received an upstream question") + //log.WithField("question", r.Question[0]).Debug("received an upstream question") select { case <-u.ctx.Done(): @@ -143,23 +143,20 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } for _, upstream := range u.upstreamServers { - log.Debugf("querying the upstream %s", upstream) - rr, errR := netroute.New() - if errR != nil { - log.Errorf("unable to create networute: %s", errR) + var ( + err error + t time.Duration + rm *dns.Msg + ) + upstreamExchangeClient := u.getClient() + if runtime.GOOS != "ios" { + ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) + rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + cancel() } else { - add := netip.MustParseAddrPort(upstream) - _, gateway, preferredSrc, errR := rr.Route(add.Addr().AsSlice()) - if errR != nil { - log.Errorf("getting routes returned an error: %v", errR) - } else { - log.Infof("upstream %s gateway: %s, preferredSrc: %s", add.Addr(), gateway, preferredSrc) - } + log.Debugf("ios upstream resolver: %s", upstream) + rm, t, err = upstreamExchangeClient.Exchange(r, upstream) } - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) - - cancel() if err != nil { if err == context.DeadlineExceeded || isTimeout(err) { @@ -169,7 +166,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } u.failsCount.Add(1) log.WithError(err).WithField("upstream", upstream). - Error("got an error while querying the upstream") + Error("got other error while querying the upstream") return } @@ -204,10 +201,11 @@ func (u *upstreamResolver) checkUpstreamFails() { case <-u.ctx.Done(): return default: - log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) - u.deactivate() - u.disabled = true - go u.waitUntilResponse() + //todo test the deactivation logic, it seems to affect the client + //log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) + //u.deactivate() + //u.disabled = true + //go u.waitUntilResponse() } }