mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-13 06:08:48 +01:00
Implement redirect logic in UDP proxy
This commit is contained in:
parent
06a17f0eee
commit
4db73a13d7
@ -92,7 +92,7 @@ func (p *ProxyBind) Pause() {
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectTo(endpoint *net.UDPAddr) {
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
|
@ -81,7 +81,7 @@ func (p *ProxyWrapper) Pause() {
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectTo(endpoint *net.UDPAddr) {
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
|
@ -11,6 +11,11 @@ type Proxy interface {
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
/*
|
||||
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||
and rewrite the src address to the endpoint address.
|
||||
With this logic can avoid the package loss from relayed connections.
|
||||
*/
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
RedirectTo(endpoint *net.UDPAddr)
|
||||
}
|
||||
|
@ -20,13 +20,15 @@ type WGUDPProxy struct {
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
srcFakerConn *SrcFaker
|
||||
sendPkg func(data []byte) (int, error)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
@ -35,6 +37,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
return p
|
||||
}
|
||||
@ -54,6 +57,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.sendPkg = p.localConn.Write
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
@ -73,15 +77,17 @@ func (p *WGUDPProxy) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
p.sendPkg = p.localConn.Write
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
@ -90,13 +96,33 @@ func (p *WGUDPProxy) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) RedirectTo(endpoint *net.UDPAddr) {
|
||||
// todo implement me
|
||||
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
defer func() {
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}()
|
||||
|
||||
p.paused = false
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create src faker conn: %s", err)
|
||||
return
|
||||
}
|
||||
p.srcFakerConn = srcFakerConn
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@ -108,6 +134,8 @@ func (p *WGUDPProxy) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) close() error {
|
||||
var result *multierror.Error
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@ -115,11 +143,14 @@ func (p *WGUDPProxy) close() error {
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
@ -127,6 +158,11 @@ func (p *WGUDPProxy) close() error {
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||
}
|
||||
|
||||
return cerrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
@ -179,14 +215,20 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
for {
|
||||
p.pausedCond.L.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.Wait()
|
||||
if !p.paused {
|
||||
break
|
||||
}
|
||||
p.pausedCond.L.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
break
|
||||
}
|
||||
_, err = p.sendPkg(buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
139
client/iface/wgproxy/udp/rawsocket.go
Normal file
139
client/iface/wgproxy/udp/rawsocket.go
Normal file
@ -0,0 +1,139 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var (
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := prepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *SrcFaker) Close() error {
|
||||
return f.rawSocket.Close()
|
||||
}
|
||||
|
||||
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
defer func() {
|
||||
if err := f.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return ipH, udpH, nil
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user