mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 09:50:47 +01:00
use dns.Client.Exchange
This commit is contained in:
parent
64084ca130
commit
65052e5cba
@ -4,16 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/libp2p/go-netroute"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
@ -40,6 +38,9 @@ type upstreamResolver struct {
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
lIP net.IP
|
||||
lName string
|
||||
iIndex int
|
||||
|
||||
deactivate func()
|
||||
reactivate func()
|
||||
@ -60,6 +61,7 @@ type upstreamResolver struct {
|
||||
func getInterfaceIndex(interfaceName string) (int, error) {
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
log.Errorf("unable to get interface by name error: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -75,54 +77,52 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr
|
||||
log.Errorf("error while parsing CIDR: %s", err)
|
||||
}
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
port := rand.Intn(4001) + 1000
|
||||
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port)
|
||||
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s", interfaceName, index, localIP)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
}
|
||||
localIFaceIndex := index // Should be our interface index
|
||||
// Create a custom dialer with the LocalAddr set to the desired IP
|
||||
|
||||
return &upstreamResolver{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
lIP: localIP,
|
||||
iIndex: localIFaceIndex,
|
||||
lName: interfaceName,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) getClient() *dns.Client {
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: localIP,
|
||||
Port: port, // Let the OS pick a free port
|
||||
IP: u.lIP,
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: upstreamTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, localIFaceIndex)
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
|
||||
}
|
||||
|
||||
if err := c.Control(fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if operr != nil {
|
||||
log.Errorf("error while setting socket option: %s", operr)
|
||||
}
|
||||
|
||||
return operr
|
||||
},
|
||||
}
|
||||
// pktConn, err := dialer.Dial("udp", "100.127.136.151:10053")
|
||||
// if err != nil {
|
||||
// log.Errorf("error while dialing: %s", err)
|
||||
//
|
||||
// } else {
|
||||
// pktConn.Write([]byte("hello"))
|
||||
// pktConn.Close()
|
||||
// }
|
||||
|
||||
// Create a new DNS client with the custom dialer
|
||||
client := &dns.Client{
|
||||
Dialer: dialer,
|
||||
}
|
||||
|
||||
return &upstreamResolver{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
upstreamClient: client,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) stop() {
|
||||
@ -134,7 +134,7 @@ func (u *upstreamResolver) stop() {
|
||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
defer u.checkUpstreamFails()
|
||||
|
||||
log.WithField("question", r.Question[0]).Debug("received an upstream question")
|
||||
//log.WithField("question", r.Question[0]).Debug("received an upstream question")
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
@ -143,23 +143,20 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
log.Debugf("querying the upstream %s", upstream)
|
||||
rr, errR := netroute.New()
|
||||
if errR != nil {
|
||||
log.Errorf("unable to create networute: %s", errR)
|
||||
var (
|
||||
err error
|
||||
t time.Duration
|
||||
rm *dns.Msg
|
||||
)
|
||||
upstreamExchangeClient := u.getClient()
|
||||
if runtime.GOOS != "ios" {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
cancel()
|
||||
} else {
|
||||
add := netip.MustParseAddrPort(upstream)
|
||||
_, gateway, preferredSrc, errR := rr.Route(add.Addr().AsSlice())
|
||||
if errR != nil {
|
||||
log.Errorf("getting routes returned an error: %v", errR)
|
||||
} else {
|
||||
log.Infof("upstream %s gateway: %s, preferredSrc: %s", add.Addr(), gateway, preferredSrc)
|
||||
}
|
||||
log.Debugf("ios upstream resolver: %s", upstream)
|
||||
rm, t, err = upstreamExchangeClient.Exchange(r, upstream)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||
@ -169,7 +166,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
u.failsCount.Add(1)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Error("got an error while querying the upstream")
|
||||
Error("got other error while querying the upstream")
|
||||
return
|
||||
}
|
||||
|
||||
@ -204,10 +201,11 @@ func (u *upstreamResolver) checkUpstreamFails() {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||
u.deactivate()
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
//todo test the deactivation logic, it seems to affect the client
|
||||
//log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||
//u.deactivate()
|
||||
//u.disabled = true
|
||||
//go u.waitUntilResponse()
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user