[client] Close the remote conn in proxy (#2626)

Port the conn close call to eBPF proxy
This commit is contained in:
Zoltan Papp 2024-09-25 18:50:10 +02:00 committed by GitHub
parent 1e4a0f77e2
commit 4ebf6e1c4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 469 additions and 257 deletions

View File

@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind() userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort) e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")

View File

@ -527,8 +527,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn) endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
@ -775,8 +775,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
} }
conn.log.Debugf("setup ice turn connection") conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn) ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil { if errClose := wgProxy.CloseConn(); errClose != nil {

View File

@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()

View File

@ -1,4 +1,4 @@
package wgproxy package ebpf
import ( import (
"fmt" "fmt"

View File

@ -1,4 +1,4 @@
package wgproxy package ebpf
import ( import (
"fmt" "fmt"

View File

@ -1,6 +1,6 @@
//go:build linux && !android //go:build linux && !android
package wgproxy package ebpf
import ( import (
"context" "context"
@ -13,47 +13,49 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
loopbackAddr = "127.0.0.1"
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int localWGListenPort int
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex turnConnMutex sync.Mutex
lastUsedPort uint16
rawConn net.PacketConn rawConn net.PacketConn
conn transport.UDPConn conn transport.UDPConn
ctx context.Context
ctxCancel context.CancelFunc
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(), ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy return wgProxy
} }
// listen load ebpf program and listen the proxy // Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error { func (p *WGEBPFProxy) Listen() error {
pl := portLookup{} pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort() wgPorxyPort, err := pl.searchFreePort()
if err != nil { if err != nil {
@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{ addr := net.UDPAddr{
Port: wgPorxyPort, Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP(loopbackAddr),
} }
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
conn, err := nbnet.ListenUDP("udp", &addr) conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil { if err != nil {
cErr := p.Free() cErr := p.Free()
@ -91,106 +95,110 @@ func (p *WGEBPFProxy) listen() error {
} }
// AddTurnConn add new turn connection for the proxy // AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn) wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go p.proxyToLocal(wgEndpointPort, turnConn) go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{ wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort), Port: int(wgEndpointPort),
} }
return wgEndpoint, nil return wgEndpoint, nil
} }
// CloseConn doing nothing because this type of proxy implementation does not store the connection // Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
func (p *WGEBPFProxy) Free() error { func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy") log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error if p.ctx != nil && p.ctx.Err() != nil {
if p.conn != nil { //nolint
err1 = p.conn.Close() return nil
} }
err2 = p.ebpfManager.FreeWGProxy() p.ctxCancel()
if p.rawConn != nil {
err3 = p.rawConn.Close() var result *multierror.Error
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
} }
if err1 != nil { if err := p.ebpfManager.FreeWGProxy(); err != nil {
return err1 result = multierror.Append(result, err)
} }
if err2 != nil { if err := p.rawConn.Close(); err != nil {
return err2 result = multierror.Append(result, err)
} }
return nberrors.FormatErrorOrNil(result)
return err3
} }
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500) buf := make([]byte, 1500)
var err error for ctx.Err() == nil {
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
select {
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf) n, err = remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF { if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
} }
return return
} }
err = p.sendPkg(buf[:n], endpointPort)
if err != nil { if err := p.sendPkg(buf[:n], endpointPort); err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err) if ctx.Err() != nil || p.ctx.Err() != nil {
return
} }
log.Errorf("failed to write out turn pkg to local conn: %v", err)
} }
} }
} }
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for p.ctx.Err() == nil {
select { if err := p.readAndForwardPacket(buf); err != nil {
case <-p.ctx.Done(): if p.ctx.Err() != nil {
return return
default: }
log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
n, addr, err := p.conn.ReadFromUDP(buf) n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil { if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err) return fmt.Errorf("failed to read UDP packet from WG: %w", err)
return
} }
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)] conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock() p.turnConnMutex.Unlock()
if !ok { if !ok {
if p.ctx.Err() == nil {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port) log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue }
return nil
} }
_, err = conn.Write(buf[:n]) if _, err := conn.Write(buf[:n]); err != nil {
if err != nil { return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
}
} }
return nil
} }
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
} }
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock() defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
_, ok := p.turnConnStore[turnConnID]
if ok {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
}
delete(p.turnConnStore, turnConnID)
} }
func (p *WGEBPFProxy) nextFreePort() (uint16, error) { func (p *WGEBPFProxy) nextFreePort() (uint16, error) {

View File

@ -1,14 +1,13 @@
//go:build linux && !android //go:build linux && !android
package wgproxy package ebpf
import ( import (
"context"
"testing" "testing"
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@ -0,0 +1,44 @@
//go:build linux && !android
package ebpf
import (
"context"
"fmt"
"net"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}

View File

@ -1,22 +0,0 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@ -3,20 +3,26 @@
package wgproxy package wgproxy
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
) )
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { type Factory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort} f := &Factory{wgPort: wgPort}
if userspace { if userspace {
return f return f
} }
ebpfProxy := NewWGEBPFProxy(ctx, wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
err := ebpfProxy.listen() err := ebpfProxy.Listen()
if err != nil { if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f return f
@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy f.ebpfProxy = ebpfProxy
return f return f
} }
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
p := &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return p
}
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@ -2,8 +2,20 @@
package wgproxy package wgproxy
import "context" import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory { type Factory struct {
wgPort int
}
func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort} return &Factory{wgPort: wgPort}
} }
func (w *Factory) GetProxy() Proxy {
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
return nil
}

View File

@ -1,12 +1,12 @@
package wgproxy package wgproxy
import ( import (
"context"
"net" "net"
) )
// Proxy is a transfer layer between the Turn connection and the WireGuard // Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface { type Proxy interface {
AddTurnConn(turnConn net.Conn) (net.Addr, error) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error CloseConn() error
Free() error
} }

View File

@ -0,0 +1,128 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@ -1,129 +0,0 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
var err error
dialer := &net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
if p.remoteConn == nil {
return nil
}
if err := p.remoteConn.Close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
defer p.cancel()
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
}

View File

@ -0,0 +1,146 @@
package usp
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
closeMu sync.Mutex
closed bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}

View File

@ -13,7 +13,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
PeerID: "test", PeerID: "test",
} }
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
go func() { go func() {