mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-12 12:56:58 +02:00
[relay,client] Relay/fix/wg roaming (#2691)
If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature. Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources. Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe.
This commit is contained in:
parent
b2379175fe
commit
0e95f16cdd
@ -82,8 +82,6 @@ type Conn struct {
|
|||||||
config ConnConfig
|
config ConnConfig
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
wgProxyFactory *wgproxy.Factory
|
wgProxyFactory *wgproxy.Factory
|
||||||
wgProxyICE wgproxy.Proxy
|
|
||||||
wgProxyRelay wgproxy.Proxy
|
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
@ -106,7 +104,8 @@ type Conn struct {
|
|||||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||||
|
|
||||||
endpointRelay *net.UDPAddr
|
wgProxyICE wgproxy.Proxy
|
||||||
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
|
||||||
// for reconnection operations
|
// for reconnection operations
|
||||||
iCEDisconnected chan bool
|
iCEDisconnected chan bool
|
||||||
@ -257,8 +256,7 @@ func (conn *Conn) Close() {
|
|||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
conn.statusICE.Set(StatusConnected)
|
|
||||||
|
|
||||||
defer conn.updateIceState(iceConnInfo)
|
|
||||||
|
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("set ICE to active connection")
|
conn.log.Infof("set ICE to active connection")
|
||||||
|
|
||||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
var (
|
||||||
if err != nil {
|
ep *net.UDPAddr
|
||||||
return
|
wgProxy wgproxy.Proxy
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if iceConnInfo.RelayedOnLocal {
|
||||||
|
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||||
|
if err != nil {
|
||||||
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = wgProxy.EndpointAddr()
|
||||||
|
conn.wgProxyICE = wgProxy
|
||||||
|
} else {
|
||||||
|
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to resolveUDPaddr")
|
||||||
|
conn.handleConfigurationFailure(err, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
|
|
||||||
conn.connIDICE = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
if conn.wgProxyRelay != nil {
|
||||||
if err != nil {
|
conn.wgProxyRelay.Pause()
|
||||||
if wgProxy != nil {
|
}
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
if wgProxy != nil {
|
||||||
}
|
wgProxy.Work()
|
||||||
}
|
}
|
||||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
|
||||||
|
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||||
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyICE = wgProxy
|
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||||
|
|
||||||
|
if conn.wgProxyICE != nil {
|
||||||
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
|
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
conn.wgProxyRelay.Work()
|
||||||
if err != nil {
|
|
||||||
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||||
conn.statusICE.Set(newState)
|
conn.statusICE.Set(newState)
|
||||||
|
|
||||||
select {
|
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||||
case conn.iCEDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
if err := rci.relayedConn.Close(); err != nil {
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
conn.statusRelay.Set(StatusConnected)
|
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
conn.endpointRelay = endpointUdpAddr
|
|
||||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
|
||||||
|
|
||||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
if conn.iceP2PIsActive() {
|
||||||
|
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
if conn.currentConnPriority > connPriorityRelay {
|
conn.wgProxyRelay = wgProxy
|
||||||
if conn.statusICE.Get() == StatusConnected {
|
conn.statusRelay.Set(StatusConnected)
|
||||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.connIDRelay = nbnet.GenerateConnID()
|
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
wgProxy.Work()
|
||||||
if err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
|
||||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyRelay = wgProxy
|
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
conn.wgProxyRelay = wgProxy
|
||||||
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
}
|
}
|
||||||
@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("relay connection is disconnected")
|
conn.log.Debugf("relay connection is disconnected")
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityRelay {
|
if conn.currentConnPriority == connPriorityRelay {
|
||||||
log.Debugf("clean up WireGuard config")
|
conn.log.Debugf("clean up WireGuard config")
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
conn.endpointRelay = nil
|
|
||||||
_ = conn.wgProxyRelay.CloseConn()
|
_ = conn.wgProxyRelay.CloseConn()
|
||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
|
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||||
select {
|
|
||||||
case conn.relayDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
Relayed: conn.isRelayed(),
|
Relayed: conn.isRelayed(),
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
}
|
}
|
||||||
|
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
|
||||||
if err != nil {
|
|
||||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||||
|
conn.connIDICE = nbnet.GenerateConnID()
|
||||||
|
for _, hook := range conn.beforeAddPeerHooks {
|
||||||
|
if err := hook(conn.connIDICE, ip); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) freeUpConnID() {
|
func (conn *Conn) freeUpConnID() {
|
||||||
if conn.connIDRelay != "" {
|
if conn.connIDRelay != "" {
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
if !iceConnInfo.RelayedOnLocal {
|
conn.log.Debugf("setup proxied WireGuard connection")
|
||||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
|
||||||
}
|
|
||||||
conn.log.Debugf("setup ice turn connection")
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
return nil, err
|
||||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
}
|
||||||
}
|
return wgProxy, nil
|
||||||
return nil, nil, err
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) iceP2PIsActive() bool {
|
||||||
|
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) removeWgPeer() error {
|
||||||
|
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.relayDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.iCEDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
|
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
|
if wgProxy != nil {
|
||||||
|
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||||
|
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.wgProxyRelay.Work()
|
||||||
}
|
}
|
||||||
return ep, wgProxy, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
|
@ -5,7 +5,6 @@ package ebpf
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn add new turn connection for the proxy
|
// AddTurnConn add new turn connection for the proxy
|
||||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
|
||||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||||
|
|
||||||
wgEndpoint := &net.UDPAddr{
|
wgEndpoint := &net.UDPAddr{
|
||||||
@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
|
||||||
defer p.removeTurnConn(endpointPort)
|
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
n int
|
|
||||||
)
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
n, err = remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
|
||||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
// From this go routine has only one instance.
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
return packetConn, nil
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||||
localhost := net.ParseIP("127.0.0.1")
|
localhost := net.ParseIP("127.0.0.1")
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
payload := gopacket.Payload(data)
|
||||||
|
@ -4,8 +4,13 @@ package ebpf
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@ -13,20 +18,55 @@ type ProxyWrapper struct {
|
|||||||
WgeBPFProxy *WGEBPFProxy
|
WgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
wgEndpointAddr *net.UDPAddr
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
ctxConn, cancel := context.WithCancel(ctx)
|
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel()
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
|
||||||
}
|
}
|
||||||
e.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
e.cancel = cancel
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return addr, err
|
p.wgEndpointAddr = addr
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
|
return p.wgEndpointAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||||
|
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
n, err := p.readFromRemote(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||||
|
n, err := p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
@ -7,6 +7,9 @@ import (
|
|||||||
|
|
||||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||||
|
EndpointAddr() *net.UDPAddr
|
||||||
|
Work()
|
||||||
|
Pause()
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,17 @@ import (
|
|||||||
// WGUserSpaceProxy proxies
|
// WGUserSpaceProxy proxies
|
||||||
type WGUserSpaceProxy struct {
|
type WGUserSpaceProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||||
@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
// The provided Context must be non-nil. If the context expires before
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
// the connection is complete, an error is returned. Once successfully
|
||||||
|
// connected, any expiration of the context will not affect the
|
||||||
p.remoteConn = remoteConn
|
// connection.
|
||||||
|
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
var err error
|
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToRemote()
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
go p.proxyToLocal()
|
p.localConn = localConn
|
||||||
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
return p.localConn.LocalAddr(), err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||||
|
if p.localConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||||
|
return endpointUdpAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work starts the proxy or resumes it if it was paused
|
||||||
|
func (p *WGUserSpaceProxy) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToRemote(p.ctx)
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
|
func (p *WGUserSpaceProxy) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to remote loop: %s", err)
|
log.Warnf("error in proxy to remote loop: %s", err)
|
||||||
@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
|
|
||||||
_, err = p.remoteConn.Write(buf[:n])
|
_, err = p.remoteConn.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||||
|
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to local loop: %s", err)
|
log.Warnf("error in proxy to local loop: %s", err)
|
||||||
@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for {
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
_, err = p.localConn.Write(buf[:n])
|
_, err = p.localConn.Write(buf[:n])
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user