Add stop methods and improve udp implementation

This commit is contained in:
Viktor Liu 2024-12-30 13:34:51 +01:00
parent b43a8c56df
commit fad82ee65c
5 changed files with 205 additions and 83 deletions

View File

@ -30,6 +30,10 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
} }
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Reset(stateManager)
} }

View File

@ -42,6 +42,10 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
} }
if m.forwarder != nil {
m.forwarder.Stop()
}
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package forwarder package forwarder
import ( import (
"context"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -24,9 +25,12 @@ type Forwarder struct {
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
} }
func New(iface common.IFaceMapper) (*Forwarder, error) { func New(iface common.IFaceMapper) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@ -85,10 +89,13 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
}, },
}) })
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(), udpForwarder: newUDPForwarder(),
ctx: ctx,
cancel: cancel,
} }
// Set up TCP forwarder // Set up TCP forwarder
@ -118,3 +125,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
} }
return nil return nil
} }
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() error {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
return nil
}

View File

@ -1,10 +1,10 @@
package forwarder package forwarder
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@ -20,9 +20,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
dstPort := id.LocalPort dstPort := id.LocalPort
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort) dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
// Dial the destination first outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
dialer := net.Dialer{}
outConn, err := dialer.Dial("tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
return return
@ -40,8 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return return
} }
// Now that we've successfully connected to the destination, // Complete the handshake
// we can complete the incoming connection
r.Complete(false) r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
@ -59,24 +56,35 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
} }
}() }()
var wg sync.WaitGroup // Create context for managing the proxy goroutines
wg.Add(2) ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() { go func() {
defer wg.Done() n, err := io.Copy(outConn, inConn)
_, err := io.Copy(outConn, inConn) if err != nil && !isClosedError(err) {
if err != nil { log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
log.Errorf("proxyTCP: copy error: %v", err)
} }
errChan <- err
}() }()
go func() { go func() {
defer wg.Done() n, err := io.Copy(inConn, outConn)
_, err := io.Copy(inConn, outConn) if err != nil && !isClosedError(err) {
if err != nil { log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
log.Errorf("proxyTCP: copy error: %v", err)
} }
errChan <- err
}() }()
wg.Wait() select {
case <-ctx.Done():
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: copy error: %v", err)
}
return
}
} }

View File

@ -1,6 +1,8 @@
package forwarder package forwarder
import ( import (
"context"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@ -8,63 +10,113 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
) )
const ( const (
udpTimeout = 60 * time.Second udpTimeout = 60 * time.Second
maxPacketSize = 65535
) )
type udpPacketConn struct { type udpPacketConn struct {
conn *gonet.UDPConn conn *gonet.UDPConn
outConn net.Conn outConn net.Conn
lastTime time.Time lastTime time.Time
cancel context.CancelFunc
} }
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
conns map[string]*udpPacketConn conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
} }
func newUDPForwarder() *udpForwarder { func newUDPForwarder() *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
conns: make(map[string]*udpPacketConn), conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, maxPacketSize)
return &b
},
},
} }
go f.cleanup() go f.cleanup()
return f return f
} }
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections // cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() { func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
f.Lock() f.Lock()
now := time.Now() now := time.Now()
for addr, conn := range f.conns { for id, conn := range f.conns {
if now.Sub(conn.lastTime) > udpTimeout { if now.Sub(conn.lastTime) > udpTimeout {
conn.conn.Close() conn.cancel()
conn.outConn.Close() if err := conn.conn.Close(); err != nil {
delete(f.conns, addr) log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
} }
} }
f.Unlock() f.Unlock()
} }
} }
}
// handleUDP is called by the UDP forwarder for new packets // handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
id := r.ID() id := r.ID()
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil {
log.Debug("forwarder: context done, dropping UDP packet")
return
}
// Create wait queue for blocking syscalls // Create wait queue for blocking syscalls
wq := waiter.Queue{} wq := waiter.Queue{}
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
log.Errorf("Create UDP endpoint error: %v", err) log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
return return
} }
@ -72,82 +124,115 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
// Try to get existing connection or create a new one // Try to get existing connection or create a new one
f.udpForwarder.Lock() f.udpForwarder.Lock()
pConn, exists := f.udpForwarder.conns[dstAddr] defer f.udpForwarder.Unlock()
pConn, exists := f.udpForwarder.conns[id]
if !exists { if !exists {
outConn, err := net.Dial("udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.udpForwarder.Unlock()
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
log.Errorf("forwader: UDP inConn close error: %v", err) log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
} }
log.Errorf("forwarder> UDP dial error: %v", err) log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
return return
} }
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{ pConn = &udpPacketConn{
conn: inConn, conn: inConn,
outConn: outConn, outConn: outConn,
lastTime: time.Now(), lastTime: time.Now(),
cancel: connCancel,
} }
f.udpForwarder.conns[dstAddr] = pConn f.udpForwarder.conns[id] = pConn
go f.proxyUDP(pConn, dstAddr) go f.proxyUDP(connCtx, pConn, id)
} }
f.udpForwarder.Unlock()
} }
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
defer func() { defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err) log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err) log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
} }
}()
var wg sync.WaitGroup
wg.Add(2)
// Handle outbound to inbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
}()
// Handle inbound to outbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
}()
wg.Wait()
// Clean up the connection from the map
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, dstAddr) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
}()
errChan := make(chan error, 2)
go func() {
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
}()
go func() {
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
}()
select {
case <-ctx.Done():
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
}
return
}
}
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
defer f.udpForwarder.bufPool.Put(bufp)
buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
} }
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
buffer := make([]byte, 65535)
for { for {
select {
case <-ctx.Done():
return ctx.Err()
default:
n, err := src.Read(buffer) n, err := src.Read(buffer)
if err != nil { if err != nil {
log.Errorf("UDP %s read error: %v", direction, err) if isTimeout(err) {
return continue
}
return fmt.Errorf("read from %s: %w", direction, err)
} }
_, err = dst.Write(buffer[:n]) _, err = dst.Write(buffer[:n])
if err != nil { if err != nil {
log.Errorf("UDP %s write error: %v", direction, err) return fmt.Errorf("write to %s: %w", direction, err)
continue
} }
f.udpForwarder.Lock() f.udpForwarder.Lock()
if conn, ok := f.udpForwarder.conns[dstAddr]; ok { if conn, ok := f.udpForwarder.conns[id]; ok {
conn.lastTime = time.Now() conn.lastTime = time.Now()
} }
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
} }
} }
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}