mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 01:38:41 +02:00
Improve udp implementation
This commit is contained in:
parent
d2616544fe
commit
6a97d44d5d
@ -26,13 +26,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection to %v", id)
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, err2 := r.CreateEndpoint(&wq)
|
||||
if err2 != nil {
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
@ -45,6 +43,8 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection to %v", id)
|
||||
|
||||
go f.proxyTCP(inConn, outConn)
|
||||
}
|
||||
|
||||
|
@ -6,8 +6,10 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
@ -24,7 +26,7 @@ const (
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastTime time.Time
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
@ -84,22 +86,37 @@ func (f *udpForwarder) cleanup() {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.Lock()
|
||||
now := time.Now()
|
||||
var idleConns []struct {
|
||||
id stack.TransportEndpointID
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
f.RLock()
|
||||
for id, conn := range f.conns {
|
||||
if now.Sub(conn.lastTime) > udpTimeout {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
delete(f.conns, id)
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
|
||||
if conn.getIdleDuration() > udpTimeout {
|
||||
idleConns = append(idleConns, struct {
|
||||
id stack.TransportEndpointID
|
||||
conn *udpPacketConn
|
||||
}{id, conn})
|
||||
}
|
||||
}
|
||||
f.Unlock()
|
||||
f.RUnlock()
|
||||
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
|
||||
f.Lock()
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -114,47 +131,60 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
f.udpForwarder.RLock()
|
||||
pConn, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
// Try to get existing connection or create a new one
|
||||
f.udpForwarder.Lock()
|
||||
defer f.udpForwarder.Unlock()
|
||||
|
||||
pConn, exists := f.udpForwarder.conns[id]
|
||||
if !exists {
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
pConn = &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
lastTime: time.Now(),
|
||||
cancel: connCancel,
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id)
|
||||
pConn = &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
// Double-check no connection was created while we were setting up
|
||||
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
go f.proxyUDP(connCtx, pConn, id)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
|
||||
@ -175,11 +205,11 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
@ -193,9 +223,18 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
|
||||
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
|
||||
defer f.udpForwarder.bufPool.Put(bufp)
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
|
||||
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
@ -223,11 +262,7 @@ func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
if conn, ok := f.udpForwarder.conns[id]; ok {
|
||||
conn.lastTime = time.Now()
|
||||
}
|
||||
f.udpForwarder.Unlock()
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user