diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 41f415af7..6897f04a1 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "runtime" - "strings" "sync" "github.com/pion/stun/v2" @@ -108,35 +107,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { return s.udpMux, nil } -func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { - fakeUDPAddr, err := fakeAddress(peerAddress) - if err != nil { - return nil, err - } - - // force IPv4 - fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4()) - if !ok { - return nil, fmt.Errorf("failed to convert IP to netip.Addr") - } - +func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { b.endpointsMu.Lock() - b.endpoints[fakeAddr] = conn + b.endpoints[fakeIP] = conn b.endpointsMu.Unlock() - - return fakeUDPAddr, nil } -func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) { - fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4()) - if !ok { - log.Warnf("failed to convert IP to netip.Addr") - return - } - +func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { b.endpointsMu.Lock() defer b.endpointsMu.Unlock() - delete(b.endpoints, fakeAddr) + + delete(b.endpoints, fakeIP) } func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { @@ -275,21 +256,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. -func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") - } - - newAddr := &net.UDPAddr{ - IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])), - Port: peerAddress.Port, - } - return newAddr, nil -} - func getMessages(msgsPool *sync.Pool) *[]ipv6.Message { return msgsPool.Get().(*[]ipv6.Message) } diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 8a2e65382..614787e17 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "strings" "sync" log "github.com/sirupsen/logrus" @@ -16,13 +17,13 @@ import ( type ProxyBind struct { Bind *bind.ICEBind - wgAddr *net.UDPAddr - wgEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + fakeNetIP *netip.AddrPort + wgBindEndpoint *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool pausedMu sync.Mutex paused bool @@ -33,20 +34,24 @@ type ProxyBind struct { // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) + fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - p.wgAddr = addr - p.wgEndpoint = addrToEndpoint(addr) + p.fakeNetIP = fakeNetIP + p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - return err + return nil } func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return p.wgAddr + return &net.UDPAddr{ + IP: p.fakeNetIP.Addr().AsSlice(), + Port: int(p.fakeNetIP.Port()), + Zone: p.fakeNetIP.Addr().Zone(), + } } func (p *ProxyBind) Work() { @@ -54,6 +59,8 @@ func (p *ProxyBind) Work() { return } + p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.pausedMu.Lock() p.paused = false p.pausedMu.Unlock() @@ -93,7 +100,7 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.wgAddr) + p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } msg := bind.RecvMessage{ - Endpoint: p.wgEndpoint, + Endpoint: p.wgBindEndpoint, Buffer: buf[:n], } p.Bind.RecvChan <- msg @@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { - ip, _ := netip.AddrFromSlice(addr.IP.To4()) - addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) - return &bind.Endpoint{AddrPort: addrPort} +// fakeAddress returns a fake address that is used to as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. +func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { + octets := strings.Split(peerAddress.IP.String(), ".") + if len(octets) != 4 { + return nil, fmt.Errorf("invalid IP format") + } + + fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) + if err != nil { + return nil, fmt.Errorf("failed to parse new IP: %w", err) + } + + netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) + return &netipAddr, nil } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 9b4d1a554..b91cfe33c 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -442,8 +442,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - if conn.iceP2PIsActive() { - conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) + if conn.isICEActive() { + conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.setRelayedProxy(wgProxy) conn.statusRelay.Set(StatusConnected) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) @@ -711,8 +711,8 @@ func (conn *Conn) isReadyToUpgrade() bool { return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay } -func (conn *Conn) iceP2PIsActive() bool { - return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +func (conn *Conn) isICEActive() bool { + return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected } func (conn *Conn) removeWgPeer() error {