mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 10:18:50 +02:00
Add stop methods and improve udp implementation
This commit is contained in:
parent
b43a8c56df
commit
fad82ee65c
@ -30,6 +30,10 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset(stateManager)
|
return m.nativeFirewall.Reset(stateManager)
|
||||||
}
|
}
|
||||||
|
@ -42,6 +42,10 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.forwarder != nil {
|
||||||
|
m.forwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -24,9 +25,12 @@ type Forwarder struct {
|
|||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper) (*Forwarder, error) {
|
func New(iface common.IFaceMapper) (*Forwarder, error) {
|
||||||
|
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@ -85,10 +89,13 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(),
|
udpForwarder: newUDPForwarder(),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up TCP forwarder
|
// Set up TCP forwarder
|
||||||
@ -118,3 +125,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the forwarder
|
||||||
|
func (f *Forwarder) Stop() error {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
if f.udpForwarder != nil {
|
||||||
|
f.udpForwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stack.Close()
|
||||||
|
f.stack.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
@ -20,9 +20,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
dstPort := id.LocalPort
|
dstPort := id.LocalPort
|
||||||
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
||||||
|
|
||||||
// Dial the destination first
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
dialer := net.Dialer{}
|
|
||||||
outConn, err := dialer.Dial("tcp", dialAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
return
|
return
|
||||||
@ -40,8 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now that we've successfully connected to the destination,
|
// Complete the handshake
|
||||||
// we can complete the incoming connection
|
|
||||||
r.Complete(false)
|
r.Complete(false)
|
||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
@ -59,24 +56,35 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
// Create context for managing the proxy goroutines
|
||||||
wg.Add(2)
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
n, err := io.Copy(outConn, inConn)
|
||||||
_, err := io.Copy(outConn, inConn)
|
if err != nil && !isClosedError(err) {
|
||||||
if err != nil {
|
log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
|
||||||
log.Errorf("proxyTCP: copy error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
n, err := io.Copy(inConn, outConn)
|
||||||
_, err := io.Copy(inConn, outConn)
|
if err != nil && !isClosedError(err) {
|
||||||
if err != nil {
|
log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
|
||||||
log.Errorf("proxyTCP: copy error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
log.Errorf("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -8,49 +10,94 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
udpTimeout = 60 * time.Second
|
udpTimeout = 60 * time.Second
|
||||||
|
maxPacketSize = 65535
|
||||||
)
|
)
|
||||||
|
|
||||||
type udpPacketConn struct {
|
type udpPacketConn struct {
|
||||||
conn *gonet.UDPConn
|
conn *gonet.UDPConn
|
||||||
outConn net.Conn
|
outConn net.Conn
|
||||||
lastTime time.Time
|
lastTime time.Time
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
conns map[string]*udpPacketConn
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
|
bufPool sync.Pool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder() *udpForwarder {
|
func newUDPForwarder() *udpForwarder {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
conns: make(map[string]*udpPacketConn),
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, maxPacketSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
go f.cleanup()
|
go f.cleanup()
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops the UDP forwarder and all active connections
|
||||||
|
func (f *udpForwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
delete(f.conns, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cleanup periodically removes idle UDP connections
|
// cleanup periodically removes idle UDP connections
|
||||||
func (f *udpForwarder) cleanup() {
|
func (f *udpForwarder) cleanup() {
|
||||||
ticker := time.NewTicker(time.Minute)
|
ticker := time.NewTicker(time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for {
|
||||||
f.Lock()
|
select {
|
||||||
now := time.Now()
|
case <-f.ctx.Done():
|
||||||
for addr, conn := range f.conns {
|
return
|
||||||
if now.Sub(conn.lastTime) > udpTimeout {
|
case <-ticker.C:
|
||||||
conn.conn.Close()
|
f.Lock()
|
||||||
conn.outConn.Close()
|
now := time.Now()
|
||||||
delete(f.conns, addr)
|
for id, conn := range f.conns {
|
||||||
|
if now.Sub(conn.lastTime) > udpTimeout {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
delete(f.conns, id)
|
||||||
|
log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
f.Unlock()
|
||||||
}
|
}
|
||||||
f.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,12 +106,17 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
id := r.ID()
|
id := r.ID()
|
||||||
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||||
|
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
log.Debug("forwarder: context done, dropping UDP packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Create wait queue for blocking syscalls
|
// Create wait queue for blocking syscalls
|
||||||
wq := waiter.Queue{}
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
ep, err := r.CreateEndpoint(&wq)
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Create UDP endpoint error: %v", err)
|
log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,82 +124,115 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
|
|
||||||
// Try to get existing connection or create a new one
|
// Try to get existing connection or create a new one
|
||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
pConn, exists := f.udpForwarder.conns[dstAddr]
|
defer f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
pConn, exists := f.udpForwarder.conns[id]
|
||||||
if !exists {
|
if !exists {
|
||||||
outConn, err := net.Dial("udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.udpForwarder.Unlock()
|
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
log.Errorf("forwader: UDP inConn close error: %v", err)
|
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
}
|
}
|
||||||
log.Errorf("forwarder> UDP dial error: %v", err)
|
log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
pConn = &udpPacketConn{
|
pConn = &udpPacketConn{
|
||||||
conn: inConn,
|
conn: inConn,
|
||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
lastTime: time.Now(),
|
lastTime: time.Now(),
|
||||||
|
cancel: connCancel,
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[dstAddr] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
|
||||||
go f.proxyUDP(pConn, dstAddr)
|
go f.proxyUDP(connCtx, pConn, id)
|
||||||
}
|
}
|
||||||
f.udpForwarder.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil {
|
||||||
log.Errorf("forwarder: inConn close error: %v", err)
|
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
log.Errorf("forwarder: outConn close error: %v", err)
|
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
// Handle outbound to inbound traffic
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Handle inbound to outbound traffic
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Clean up the connection from the map
|
|
||||||
f.udpForwarder.Lock()
|
|
||||||
delete(f.udpForwarder.conns, dstAddr)
|
|
||||||
f.udpForwarder.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
|
|
||||||
buffer := make([]byte, 65535)
|
|
||||||
for {
|
|
||||||
n, err := src.Read(buffer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("UDP %s read error: %v", direction, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = dst.Write(buffer[:n])
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("UDP %s write error: %v", direction, err)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
|
delete(f.udpForwarder.conns, id)
|
||||||
conn.lastTime = time.Now()
|
|
||||||
}
|
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
|
||||||
|
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
|
||||||
|
defer f.udpForwarder.bufPool.Put(bufp)
|
||||||
|
buffer := *bufp
|
||||||
|
|
||||||
|
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if isTimeout(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read from %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
if conn, ok := f.udpForwarder.conns[id]; ok {
|
||||||
|
conn.lastTime = time.Now()
|
||||||
|
}
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedError(err error) bool {
|
||||||
|
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user