diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 20804175e..c1cc8dd30 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -110,6 +110,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -125,7 +129,7 @@ func (p *ProxyBind) close() error { p.pausedCond.L.Unlock() p.pausedCond.Signal() - p.bind.RemoveEndpoint(bind.EndpointToUDPAddr(*p.wgCurrentUsed)) + p.bind.RemoveEndpoint(bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 7c4faa6b7..80ca57564 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -7,14 +7,18 @@ import ( "net" "testing" + "github.com/netbirdio/netbird/client/iface/bind" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/udp" + "github.com/netbirdio/netbird/util" ) func init() { _ = util.InitLog("debug", "console") } + func TestProxyRedirect(t *testing.T) { ebpfProxy := ebpf.NewWGEBPFProxy(51831) if err := ebpfProxy.Listen(); err != nil { @@ -28,9 +32,10 @@ func TestProxyRedirect(t *testing.T) { }() tests := []struct { - name string - proxy Proxy - wgPort int + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr }{ { name: "ebpf kernel proxy", @@ -45,12 +50,12 @@ func TestProxyRedirect(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - redirectTraffic(t, tt.proxy, tt.wgPort) + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) }) } } -func redirectTraffic(t *testing.T, proxy Proxy, wgPort int) { +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { t.Helper() msgHelloFromRelay := []byte("hello from relay") @@ -82,7 +87,7 @@ func redirectTraffic(t *testing.T, proxy Proxy, wgPort int) { _ = relayedServer.Close() }() - if err := proxy.AddTurnConn(context.Background(), nil, relayedConn); err != nil { + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { t.Errorf("error: %v", err) } defer func() { @@ -134,7 +139,7 @@ func redirectTraffic(t *testing.T, proxy Proxy, wgPort int) { } } -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { +func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() ebpfProxy := ebpf.NewWGEBPFProxy(51831) @@ -148,9 +153,15 @@ func TestProxyCloseByRemoteConnEBPF(t *testing.T) { } }() + iceBind := bind.NewICEBind(nil, nil) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } tests := []struct { - name string - proxy Proxy + name string + proxy Proxy + endpointAddress *net.UDPAddr }{ { name: "ebpf proxy", @@ -160,6 +171,11 @@ func TestProxyCloseByRemoteConnEBPF(t *testing.T) { name: "udp proxy", proxy: udp.NewWGUDPProxy(51832), }, + { + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddress: endpointAddress, + }, } relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") @@ -168,7 +184,7 @@ func TestProxyCloseByRemoteConnEBPF(t *testing.T) { }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, endpointAddress, relayedConn) if err != nil { t.Errorf("error: %v", err) }