mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
Improve udp implementation
This commit is contained in:
parent
d2616544fe
commit
6a97d44d5d
@ -26,13 +26,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection to %v", id)
|
|
||||||
|
|
||||||
// Create wait queue for blocking syscalls
|
// Create wait queue for blocking syscalls
|
||||||
wq := waiter.Queue{}
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
ep, err2 := r.CreateEndpoint(&wq)
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
if err2 != nil {
|
if epErr != nil {
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Error("forwarder: outConn close error: %v", err)
|
f.logger.Error("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
@ -45,6 +43,8 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established TCP connection to %v", id)
|
||||||
|
|
||||||
go f.proxyTCP(inConn, outConn)
|
go f.proxyTCP(inConn, outConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,8 +6,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
@ -24,7 +26,7 @@ const (
|
|||||||
type udpPacketConn struct {
|
type udpPacketConn struct {
|
||||||
conn *gonet.UDPConn
|
conn *gonet.UDPConn
|
||||||
outConn net.Conn
|
outConn net.Conn
|
||||||
lastTime time.Time
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,22 +86,37 @@ func (f *udpForwarder) cleanup() {
|
|||||||
case <-f.ctx.Done():
|
case <-f.ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
f.Lock()
|
var idleConns []struct {
|
||||||
now := time.Now()
|
id stack.TransportEndpointID
|
||||||
|
conn *udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
f.RLock()
|
||||||
for id, conn := range f.conns {
|
for id, conn := range f.conns {
|
||||||
if now.Sub(conn.lastTime) > udpTimeout {
|
if conn.getIdleDuration() > udpTimeout {
|
||||||
conn.cancel()
|
idleConns = append(idleConns, struct {
|
||||||
if err := conn.conn.Close(); err != nil {
|
id stack.TransportEndpointID
|
||||||
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
|
conn *udpPacketConn
|
||||||
}
|
}{id, conn})
|
||||||
if err := conn.outConn.Close(); err != nil {
|
|
||||||
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
|
|
||||||
}
|
|
||||||
delete(f.conns, id)
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.Unlock()
|
f.RUnlock()
|
||||||
|
|
||||||
|
for _, idle := range idleConns {
|
||||||
|
idle.conn.cancel()
|
||||||
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Error("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
delete(f.conns, idle.id)
|
||||||
|
f.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -114,47 +131,60 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.udpForwarder.RLock()
|
||||||
|
pConn, exists := f.udpForwarder.conns[id]
|
||||||
|
f.udpForwarder.RUnlock()
|
||||||
|
if exists {
|
||||||
|
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
// TODO: Send ICMP error message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Create wait queue for blocking syscalls
|
// Create wait queue for blocking syscalls
|
||||||
wq := waiter.Queue{}
|
wq := waiter.Queue{}
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
ep, err := r.CreateEndpoint(&wq)
|
if epErr != nil {
|
||||||
if err != nil {
|
f.logger.Error("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
|
|
||||||
// Try to get existing connection or create a new one
|
pConn = &udpPacketConn{
|
||||||
f.udpForwarder.Lock()
|
conn: inConn,
|
||||||
defer f.udpForwarder.Unlock()
|
outConn: outConn,
|
||||||
|
cancel: connCancel,
|
||||||
pConn, exists := f.udpForwarder.conns[id]
|
|
||||||
if !exists {
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
|
||||||
if err != nil {
|
|
||||||
if err := inConn.Close(); err != nil {
|
|
||||||
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
|
|
||||||
}
|
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
|
||||||
|
|
||||||
// TODO: Send ICMP error message
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
|
||||||
|
|
||||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
|
||||||
pConn = &udpPacketConn{
|
|
||||||
conn: inConn,
|
|
||||||
outConn: outConn,
|
|
||||||
lastTime: time.Now(),
|
|
||||||
cancel: connCancel,
|
|
||||||
}
|
|
||||||
f.udpForwarder.conns[id] = pConn
|
|
||||||
|
|
||||||
go f.proxyUDP(connCtx, pConn, id)
|
|
||||||
}
|
}
|
||||||
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
// Double-check no connection was created while we were setting up
|
||||||
|
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
pConn.cancel()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||||
|
go f.proxyUDP(connCtx, pConn, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
|
||||||
@ -175,11 +205,11 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
|
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
|
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -193,9 +223,18 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
defer f.udpForwarder.bufPool.Put(bufp)
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||||
|
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||||
|
bufp := bufPool.Get().(*[]byte)
|
||||||
|
defer bufPool.Put(bufp)
|
||||||
buffer := *bufp
|
buffer := *bufp
|
||||||
|
|
||||||
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
@ -223,11 +262,7 @@ func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id
|
|||||||
return fmt.Errorf("write to %s: %w", direction, err)
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.udpForwarder.Lock()
|
c.updateLastSeen()
|
||||||
if conn, ok := f.udpForwarder.conns[id]; ok {
|
|
||||||
conn.lastTime = time.Now()
|
|
||||||
}
|
|
||||||
f.udpForwarder.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user