Improve udp implementation

This commit is contained in:
Viktor Liu 2024-12-30 20:50:20 +01:00
parent d2616544fe
commit 6a97d44d5d
2 changed files with 96 additions and 61 deletions

View File

@ -26,13 +26,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return return
} }
f.logger.Trace("forwarder: established TCP connection to %v", id)
// Create wait queue for blocking syscalls // Create wait queue for blocking syscalls
wq := waiter.Queue{} wq := waiter.Queue{}
ep, err2 := r.CreateEndpoint(&wq) ep, epErr := r.CreateEndpoint(&wq)
if err2 != nil { if epErr != nil {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Error("forwarder: outConn close error: %v", err) f.logger.Error("forwarder: outConn close error: %v", err)
} }
@ -45,6 +43,8 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection to %v", id)
go f.proxyTCP(inConn, outConn) go f.proxyTCP(inConn, outConn)
} }

View File

@ -6,8 +6,10 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@ -24,7 +26,7 @@ const (
type udpPacketConn struct { type udpPacketConn struct {
conn *gonet.UDPConn conn *gonet.UDPConn
outConn net.Conn outConn net.Conn
lastTime time.Time lastSeen atomic.Int64
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -84,22 +86,37 @@ func (f *udpForwarder) cleanup() {
case <-f.ctx.Done(): case <-f.ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
f.Lock() var idleConns []struct {
now := time.Now() id stack.TransportEndpointID
conn *udpPacketConn
}
f.RLock()
for id, conn := range f.conns { for id, conn := range f.conns {
if now.Sub(conn.lastTime) > udpTimeout { if conn.getIdleDuration() > udpTimeout {
conn.cancel() idleConns = append(idleConns, struct {
if err := conn.conn.Close(); err != nil { id stack.TransportEndpointID
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err) conn *udpPacketConn
} }{id, conn})
if err := conn.outConn.Close(); err != nil {
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
} }
} }
f.Unlock() f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Error("forwarder: UDP conn close error for %v: %v", idle.id, err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err)
}
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
}
} }
} }
} }
@ -114,47 +131,60 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return return
} }
f.udpForwarder.RLock()
pConn, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id)
return
}
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls // Create wait queue for blocking syscalls
wq := waiter.Queue{} wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
ep, err := r.CreateEndpoint(&wq) if epErr != nil {
if err != nil { f.logger.Error("forwarder: failed to create UDP endpoint: %v", epErr)
f.logger.Error("forwarder: failed to create UDP endpoint: %v", err) if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return return
} }
inConn := gonet.NewUDPConn(f.stack, &wq, ep) inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
// Try to get existing connection or create a new one pConn = &udpPacketConn{
f.udpForwarder.Lock() conn: inConn,
defer f.udpForwarder.Unlock() outConn: outConn,
cancel: connCancel,
pConn, exists := f.udpForwarder.conns[id]
if !exists {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
if err := inConn.Close(); err != nil {
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
}
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
f.logger.Trace("forwarder: established UDP connection to %v", id)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{
conn: inConn,
outConn: outConn,
lastTime: time.Now(),
cancel: connCancel,
}
f.udpForwarder.conns[id] = pConn
go f.proxyUDP(connCtx, pConn, id)
} }
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id)
go f.proxyUDP(connCtx, pConn, id)
} }
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
@ -175,11 +205,11 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
errChan := make(chan error, 2) errChan := make(chan error, 2)
go func() { go func() {
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound") errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}() }()
go func() { go func() {
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound") errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}() }()
select { select {
@ -193,9 +223,18 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
} }
} }
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error { func (c *udpPacketConn) updateLastSeen() {
bufp := f.udpForwarder.bufPool.Get().(*[]byte) c.lastSeen.Store(time.Now().UnixNano())
defer f.udpForwarder.bufPool.Put(bufp) }
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
@ -223,11 +262,7 @@ func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id
return fmt.Errorf("write to %s: %w", direction, err) return fmt.Errorf("write to %s: %w", direction, err)
} }
f.udpForwarder.Lock() c.updateLastSeen()
if conn, ok := f.udpForwarder.conns[id]; ok {
conn.lastTime = time.Now()
}
f.udpForwarder.Unlock()
} }
} }
} }