use dns.Client.Exchange

This commit is contained in:
Maycon Santos 2023-11-03 20:35:52 +01:00
parent 64084ca130
commit 65052e5cba

View File

@ -4,16 +4,14 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/libp2p/go-netroute"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
@ -40,6 +38,9 @@ type upstreamResolver struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
lIP net.IP
lName string
iIndex int
deactivate func()
reactivate func()
@ -60,6 +61,7 @@ type upstreamResolver struct {
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
log.Errorf("unable to get interface by name error: %s", err)
return 0, err
}
@ -75,54 +77,52 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr
log.Errorf("error while parsing CIDR: %s", err)
}
index, err := getInterfaceIndex(interfaceName)
rand.Seed(time.Now().UnixNano())
port := rand.Intn(4001) + 1000
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port)
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s", interfaceName, index, localIP)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
}
localIFaceIndex := index // Should be our interface index
// Create a custom dialer with the LocalAddr set to the desired IP
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
lIP: localIP,
iIndex: localIFaceIndex,
lName: interfaceName,
}
}
func (u *upstreamResolver) getClient() *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: localIP,
Port: port, // Let the OS pick a free port
IP: u.lIP,
Port: 0, // Let the OS pick a free port
},
Timeout: upstreamTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, localIFaceIndex)
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
}
if err := c.Control(fn); err != nil {
return err
}
if operr != nil {
log.Errorf("error while setting socket option: %s", operr)
}
return operr
},
}
// pktConn, err := dialer.Dial("udp", "100.127.136.151:10053")
// if err != nil {
// log.Errorf("error while dialing: %s", err)
//
// } else {
// pktConn.Write([]byte("hello"))
// pktConn.Close()
// }
// Create a new DNS client with the custom dialer
client := &dns.Client{
Dialer: dialer,
}
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamClient: client,
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
return client
}
func (u *upstreamResolver) stop() {
@ -134,7 +134,7 @@ func (u *upstreamResolver) stop() {
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Debug("received an upstream question")
//log.WithField("question", r.Question[0]).Debug("received an upstream question")
select {
case <-u.ctx.Done():
@ -143,23 +143,20 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
for _, upstream := range u.upstreamServers {
log.Debugf("querying the upstream %s", upstream)
rr, errR := netroute.New()
if errR != nil {
log.Errorf("unable to create networute: %s", errR)
var (
err error
t time.Duration
rm *dns.Msg
)
upstreamExchangeClient := u.getClient()
if runtime.GOOS != "ios" {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
} else {
add := netip.MustParseAddrPort(upstream)
_, gateway, preferredSrc, errR := rr.Route(add.Addr().AsSlice())
if errR != nil {
log.Errorf("getting routes returned an error: %v", errR)
} else {
log.Infof("upstream %s gateway: %s, preferredSrc: %s", add.Addr(), gateway, preferredSrc)
}
log.Debugf("ios upstream resolver: %s", upstream)
rm, t, err = upstreamExchangeClient.Exchange(r, upstream)
}
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
@ -169,7 +166,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream")
Error("got other error while querying the upstream")
return
}
@ -204,10 +201,11 @@ func (u *upstreamResolver) checkUpstreamFails() {
case <-u.ctx.Done():
return
default:
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilResponse()
//todo test the deactivation logic, it seems to affect the client
//log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
//u.deactivate()
//u.disabled = true
//go u.waitUntilResponse()
}
}