Improve udp implementation

This commit is contained in:
Viktor Liu 2024-12-30 20:50:20 +01:00
parent d2616544fe
commit 6a97d44d5d
2 changed files with 96 additions and 61 deletions

View File

@ -26,13 +26,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
f.logger.Trace("forwarder: established TCP connection to %v", id)
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err2 := r.CreateEndpoint(&wq)
if err2 != nil {
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
if err := outConn.Close(); err != nil {
f.logger.Error("forwarder: outConn close error: %v", err)
}
@ -45,6 +43,8 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection to %v", id)
go f.proxyTCP(inConn, outConn)
}

View File

@ -6,8 +6,10 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@ -24,7 +26,7 @@ const (
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastTime time.Time
lastSeen atomic.Int64
cancel context.CancelFunc
}
@ -84,22 +86,37 @@ func (f *udpForwarder) cleanup() {
case <-f.ctx.Done():
return
case <-ticker.C:
f.Lock()
now := time.Now()
var idleConns []struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
f.RLock()
for id, conn := range f.conns {
if now.Sub(conn.lastTime) > udpTimeout {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, struct {
id stack.TransportEndpointID
conn *udpPacketConn
}{id, conn})
}
}
f.Unlock()
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Error("forwarder: UDP conn close error for %v: %v", idle.id, err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err)
}
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
}
}
}
}
@ -114,47 +131,60 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return
}
f.udpForwarder.RLock()
pConn, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id)
return
}
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err := r.CreateEndpoint(&wq)
if err != nil {
f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
// Try to get existing connection or create a new one
f.udpForwarder.Lock()
defer f.udpForwarder.Unlock()
pConn, exists := f.udpForwarder.conns[id]
if !exists {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
if err := inConn.Close(); err != nil {
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
}
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
f.logger.Trace("forwarder: established UDP connection to %v", id)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{
conn: inConn,
outConn: outConn,
lastTime: time.Now(),
cancel: connCancel,
}
f.udpForwarder.conns[id] = pConn
go f.proxyUDP(connCtx, pConn, id)
pConn = &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id)
go f.proxyUDP(connCtx, pConn, id)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
@ -175,11 +205,11 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
errChan := make(chan error, 2)
go func() {
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
@ -193,9 +223,18 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
}
}
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
defer f.udpForwarder.bufPool.Put(bufp)
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
@ -223,11 +262,7 @@ func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id
return fmt.Errorf("write to %s: %w", direction, err)
}
f.udpForwarder.Lock()
if conn, ok := f.udpForwarder.conns[id]; ok {
conn.lastTime = time.Now()
}
f.udpForwarder.Unlock()
c.updateLastSeen()
}
}
}