mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 10:18:50 +02:00
Add stop methods and improve udp implementation
This commit is contained in:
parent
b43a8c56df
commit
fad82ee65c
@ -30,6 +30,10 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
}
|
||||
|
@ -42,6 +42,10 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -24,9 +25,12 @@ type Forwarder struct {
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper) (*Forwarder, error) {
|
||||
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@ -85,10 +89,13 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set up TCP forwarder
|
||||
@ -118,3 +125,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() error {
|
||||
f.cancel()
|
||||
|
||||
if f.udpForwarder != nil {
|
||||
f.udpForwarder.Stop()
|
||||
}
|
||||
|
||||
f.stack.Close()
|
||||
f.stack.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
@ -20,9 +20,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
dstPort := id.LocalPort
|
||||
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
||||
|
||||
// Dial the destination first
|
||||
dialer := net.Dialer{}
|
||||
outConn, err := dialer.Dial("tcp", dialAddr)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
return
|
||||
@ -40,8 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
// Now that we've successfully connected to the destination,
|
||||
// we can complete the incoming connection
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
@ -59,24 +56,35 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
if err != nil {
|
||||
log.Errorf("proxyTCP: copy error: %v", err)
|
||||
n, err := io.Copy(outConn, inConn)
|
||||
if err != nil && !isClosedError(err) {
|
||||
log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
|
||||
}
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
if err != nil {
|
||||
log.Errorf("proxyTCP: copy error: %v", err)
|
||||
n, err := io.Copy(inConn, outConn)
|
||||
if err != nil && !isClosedError(err) {
|
||||
log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
|
||||
}
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
log.Errorf("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@ -8,63 +10,113 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
udpTimeout = 60 * time.Second
|
||||
maxPacketSize = 65535
|
||||
)
|
||||
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastTime time.Time
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
conns map[string]*udpPacketConn
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newUDPForwarder() *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
conns: make(map[string]*udpPacketConn),
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, maxPacketSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
go f.cleanup()
|
||||
return f
|
||||
}
|
||||
|
||||
// Stop stops the UDP forwarder and all active connections
|
||||
func (f *udpForwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
delete(f.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.Lock()
|
||||
now := time.Now()
|
||||
for addr, conn := range f.conns {
|
||||
for id, conn := range f.conns {
|
||||
if now.Sub(conn.lastTime) > udpTimeout {
|
||||
conn.conn.Close()
|
||||
conn.outConn.Close()
|
||||
delete(f.conns, addr)
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
delete(f.conns, id)
|
||||
log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
|
||||
}
|
||||
}
|
||||
f.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||
|
||||
if f.ctx.Err() != nil {
|
||||
log.Debug("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
log.Errorf("Create UDP endpoint error: %v", err)
|
||||
log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -72,82 +124,115 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
|
||||
// Try to get existing connection or create a new one
|
||||
f.udpForwarder.Lock()
|
||||
pConn, exists := f.udpForwarder.conns[dstAddr]
|
||||
defer f.udpForwarder.Unlock()
|
||||
|
||||
pConn, exists := f.udpForwarder.conns[id]
|
||||
if !exists {
|
||||
outConn, err := net.Dial("udp", dstAddr)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.udpForwarder.Unlock()
|
||||
if err := inConn.Close(); err != nil {
|
||||
log.Errorf("forwader: UDP inConn close error: %v", err)
|
||||
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
log.Errorf("forwarder> UDP dial error: %v", err)
|
||||
log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
pConn = &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
lastTime: time.Now(),
|
||||
cancel: connCancel,
|
||||
}
|
||||
f.udpForwarder.conns[dstAddr] = pConn
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
|
||||
go f.proxyUDP(pConn, dstAddr)
|
||||
go f.proxyUDP(connCtx, pConn, id)
|
||||
}
|
||||
f.udpForwarder.Unlock()
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
log.Errorf("forwarder: inConn close error: %v", err)
|
||||
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: outConn close error: %v", err)
|
||||
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Handle outbound to inbound traffic
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
|
||||
}()
|
||||
|
||||
// Handle inbound to outbound traffic
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Clean up the connection from the map
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, dstAddr)
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
|
||||
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
|
||||
defer f.udpForwarder.bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
|
||||
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
|
||||
buffer := make([]byte, 65535)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
log.Errorf("UDP %s read error: %v", direction, err)
|
||||
return
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
log.Errorf("UDP %s write error: %v", direction, err)
|
||||
continue
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
|
||||
if conn, ok := f.udpForwarder.conns[id]; ok {
|
||||
conn.lastTime = time.Now()
|
||||
}
|
||||
f.udpForwarder.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user