Make sure the iOS dialer does not get overwritten (#1585)

* Make sure our iOS dialer does not get overwritten

* set dial timeout for both clients on ios

---------

Co-authored-by: Pascal Fischer <pascal@netbird.io>
This commit is contained in:
Viktor Liu 2024-02-16 14:37:47 +01:00 committed by GitHub
parent cf87f1e702
commit 0afd738509
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,24 +46,32 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
if err != nil { if err != nil {
log.Errorf("error while parsing upstream host: %s", err) log.Errorf("error while parsing upstream host: %s", err)
} }
timeout := upstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost) upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate() client = u.getClientPrivate(timeout)
} }
return client.ExchangeContext(ctx, r, upstream) // Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
} }
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate() *dns.Client { func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: u.lIP, IP: u.lIP,
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: upstreamTimeout, Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {