[client] Update local interface addresses when gathering candidates (#3324)

This commit is contained in:
Viktor Liu 2025-02-21 19:44:50 +01:00 committed by GitHub
parent 73101c8977
commit 9a0354b681
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 48 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
var localAddrsForUnspecified []net.Addr mux := &UDPMuxDefault{
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize) return newBufferHolder(receiveMTU + maxAddrSize)
}, },
}, },
localAddrsForUnspecified: localAddrsForUnspecified,
} }
mux.updateLocalAddresses()
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
m.mu.Lock()
m.localAddrsForUnspecified = localAddrsForUnspecified
m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this UDPMuxDefault
@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 { if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified return slices.Clone(m.localAddrsForUnspecified)
} }
return []net.Addr{m.LocalAddr()} return []net.Addr{m.LocalAddr()}
@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
m.mu.Unlock()
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String()) return nil, fmt.Errorf("invalid address %s", addr.String())
} }

View File

@ -413,7 +413,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true") t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil) newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil { if err != nil {
t.Errorf("create stdnet: %v", err) t.Errorf("create stdnet: %v", err)
return return
@ -887,7 +887,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true") t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil) newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil { if err != nil {
t.Fatalf("create stdnet: %v", err) t.Fatalf("create stdnet: %v", err)
return nil, err return nil, err

View File

@ -21,7 +21,6 @@ func InterfaceFilter(disallowList []string) func(string) bool {
for _, s := range disallowList { for _, s := range disallowList {
if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" { if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" {
log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false return false
} }
} }

View File

@ -5,11 +5,16 @@ package stdnet
import ( import (
"fmt" "fmt"
"slices"
"sync"
"time"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
) )
const updateInterval = 30 * time.Second
// Net is an implementation of the net.Net interface // Net is an implementation of the net.Net interface
// based on functions of the standard net package. // based on functions of the standard net package.
type Net struct { type Net struct {
@ -18,6 +23,10 @@ type Net struct {
iFaceDiscover iFaceDiscover iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed // interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool interfaceFilter func(interfaceName string) bool
lastUpdate time.Time
// mu is shared between interfaces and lastUpdate
mu sync.Mutex
} }
// NewNetWithDiscover creates a new StdNet instance. // NewNetWithDiscover creates a new StdNet instance.
@ -43,18 +52,40 @@ func NewNet(disallowList []string) (*Net, error) {
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
// wasn't specified. // wasn't specified.
func (n *Net) UpdateInterfaces() (err error) { func (n *Net) UpdateInterfaces() (err error) {
n.mu.Lock()
defer n.mu.Unlock()
return n.updateInterfaces()
}
func (n *Net) updateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces() allIfaces, err := n.iFaceDiscover.iFaces()
if err != nil { if err != nil {
return err return err
} }
n.interfaces = n.filterInterfaces(allIfaces) n.interfaces = n.filterInterfaces(allIfaces)
n.lastUpdate = time.Now()
return nil return nil
} }
// Interfaces returns a slice of interfaces which are available on the // Interfaces returns a slice of interfaces which are available on the
// system // system
func (n *Net) Interfaces() ([]*transport.Interface, error) { func (n *Net) Interfaces() ([]*transport.Interface, error) {
return n.interfaces, nil n.mu.Lock()
defer n.mu.Unlock()
if time.Since(n.lastUpdate) < updateInterval {
return slices.Clone(n.interfaces), nil
}
if err := n.updateInterfaces(); err != nil {
return nil, fmt.Errorf("update interfaces: %w", err)
}
return slices.Clone(n.interfaces), nil
} }
// InterfaceByIndex returns the interface specified by index. // InterfaceByIndex returns the interface specified by index.
@ -63,6 +94,8 @@ func (n *Net) Interfaces() ([]*transport.Interface, error) {
// sharing the logical data link; for more precision use // sharing the logical data link; for more precision use
// InterfaceByName. // InterfaceByName.
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
n.mu.Lock()
defer n.mu.Unlock()
for _, ifc := range n.interfaces { for _, ifc := range n.interfaces {
if ifc.Index == index { if ifc.Index == index {
return ifc, nil return ifc, nil
@ -74,6 +107,8 @@ func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
// InterfaceByName returns the interface specified by name. // InterfaceByName returns the interface specified by name.
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) { func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
n.mu.Lock()
defer n.mu.Unlock()
for _, ifc := range n.interfaces { for _, ifc := range n.interfaces {
if ifc.Name == name { if ifc.Name == name {
return ifc, nil return ifc, nil
@ -87,7 +122,7 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
if n.interfaceFilter == nil { if n.interfaceFilter == nil {
return interfaces return interfaces
} }
result := []*transport.Interface{} var result []*transport.Interface
for _, iface := range interfaces { for _, iface := range interfaces {
if n.interfaceFilter(iface.Name) { if n.interfaceFilter(iface.Name) {
result = append(result, iface) result = append(result, iface)