From 9a0354b681fbb8e73be265252b99eb09242d0ab3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 21 Feb 2025 19:44:50 +0100 Subject: [PATCH] [client] Update local interface addresses when gathering candidates (#3324) --- client/iface/bind/udp_mux.go | 102 +++++++++++++++++------------ client/internal/dns/server_test.go | 4 +- client/internal/stdnet/filter.go | 1 - client/internal/stdnet/stdnet.go | 39 ++++++++++- 4 files changed, 98 insertions(+), 48 deletions(-) diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 00a91f0ec..4c827de95 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "slices" "strings" "sync" @@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } - var localAddrsForUnspecified []net.Addr - if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { - params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) - } else if ok && addr.IP.IsUnspecified() { - // For unspecified addresses, the correct behavior is to return errListenUnspecified, but - // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") - var networks []ice.NetworkType - switch { - - case addr.IP.To16() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} - - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} - - default: - params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) - } - if len(networks) > 0 { - if params.Net == nil { - var err error - if params.Net, err = stdnet.NewNet(); err != nil { - params.Logger.Errorf("failed to get create network: %v", err) - } - } - - ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true) - if err == nil { - for _, ip := range ips { - localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) - } - } else { - params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err) - } - } - } - - return &UDPMuxDefault{ + mux := &UDPMuxDefault{ addressMap: map[string][]*udpMuxedConn{}, params: params, connsIPv4: make(map[string]*udpMuxedConn), @@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return newBufferHolder(receiveMTU + maxAddrSize) }, }, - localAddrsForUnspecified: localAddrsForUnspecified, } + + mux.updateLocalAddresses() + return mux +} + +func (m *UDPMuxDefault) updateLocalAddresses() { + var localAddrsForUnspecified []net.Addr + if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { + m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) + } else if ok && addr.IP.IsUnspecified() { + // For unspecified addresses, the correct behavior is to return errListenUnspecified, but + // it will break the applications that are already using unspecified UDP connection + // with UDPMuxDefault, so print a warn log and create a local address list for mux. + m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + var networks []ice.NetworkType + switch { + + case addr.IP.To16() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + + default: + m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr()) + } + if len(networks) > 0 { + if m.params.Net == nil { + var err error + if m.params.Net, err = stdnet.NewNet(); err != nil { + m.params.Logger.Errorf("failed to get create network: %v", err) + } + } + + ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true) + if err == nil { + for _, ip := range ips { + localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) + } + } else { + m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err) + } + } + } + + m.mu.Lock() + m.localAddrsForUnspecified = localAddrsForUnspecified + m.mu.Unlock() } // LocalAddr returns the listening address of this UDPMuxDefault @@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr { // GetListenAddresses returns the list of addresses that this mux is listening on func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { + m.updateLocalAddresses() + + m.mu.Lock() + defer m.mu.Unlock() if len(m.localAddrsForUnspecified) > 0 { - return m.localAddrsForUnspecified + return slices.Clone(m.localAddrsForUnspecified) } return []net.Addr{m.LocalAddr()} @@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // creates the connection if an existing one can't be found func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { // don't check addr for mux using unspecified address - if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { + m.mu.Lock() + lenLocalAddrs := len(m.localAddrsForUnspecified) + m.mu.Unlock() + if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { return nil, fmt.Errorf("invalid address %s", addr.String()) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 84779256f..94b87124b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -413,7 +413,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet([]string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -887,7 +887,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet([]string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/stdnet/filter.go b/client/internal/stdnet/filter.go index c04250b2d..e45714001 100644 --- a/client/internal/stdnet/filter.go +++ b/client/internal/stdnet/filter.go @@ -21,7 +21,6 @@ func InterfaceFilter(disallowList []string) func(string) bool { for _, s := range disallowList { if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" { - log.Tracef("ignoring interface %s - it is not allowed", iFace) return false } } diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 2e87475a5..aa9fdd045 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -5,11 +5,16 @@ package stdnet import ( "fmt" + "slices" + "sync" + "time" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) +const updateInterval = 30 * time.Second + // Net is an implementation of the net.Net interface // based on functions of the standard net package. type Net struct { @@ -18,6 +23,10 @@ type Net struct { iFaceDiscover iFaceDiscover // interfaceFilter should return true if the given interfaceName is allowed interfaceFilter func(interfaceName string) bool + lastUpdate time.Time + + // mu is shared between interfaces and lastUpdate + mu sync.Mutex } // NewNetWithDiscover creates a new StdNet instance. @@ -43,18 +52,40 @@ func NewNet(disallowList []string) (*Net, error) { // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one // wasn't specified. func (n *Net) UpdateInterfaces() (err error) { + n.mu.Lock() + defer n.mu.Unlock() + + return n.updateInterfaces() +} + +func (n *Net) updateInterfaces() (err error) { allIfaces, err := n.iFaceDiscover.iFaces() if err != nil { return err } + n.interfaces = n.filterInterfaces(allIfaces) + + n.lastUpdate = time.Now() + return nil } // Interfaces returns a slice of interfaces which are available on the // system func (n *Net) Interfaces() ([]*transport.Interface, error) { - return n.interfaces, nil + n.mu.Lock() + defer n.mu.Unlock() + + if time.Since(n.lastUpdate) < updateInterval { + return slices.Clone(n.interfaces), nil + } + + if err := n.updateInterfaces(); err != nil { + return nil, fmt.Errorf("update interfaces: %w", err) + } + + return slices.Clone(n.interfaces), nil } // InterfaceByIndex returns the interface specified by index. @@ -63,6 +94,8 @@ func (n *Net) Interfaces() ([]*transport.Interface, error) { // sharing the logical data link; for more precision use // InterfaceByName. func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { + n.mu.Lock() + defer n.mu.Unlock() for _, ifc := range n.interfaces { if ifc.Index == index { return ifc, nil @@ -74,6 +107,8 @@ func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { // InterfaceByName returns the interface specified by name. func (n *Net) InterfaceByName(name string) (*transport.Interface, error) { + n.mu.Lock() + defer n.mu.Unlock() for _, ifc := range n.interfaces { if ifc.Name == name { return ifc, nil @@ -87,7 +122,7 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I if n.interfaceFilter == nil { return interfaces } - result := []*transport.Interface{} + var result []*transport.Interface for _, iface := range interfaces { if n.interfaceFilter(iface.Name) { result = append(result, iface)