Add stop methods and improve udp implementation

This commit is contained in:
Viktor Liu 2024-12-30 13:34:51 +01:00
parent b43a8c56df
commit fad82ee65c
5 changed files with 205 additions and 83 deletions

View File

@ -30,6 +30,10 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}

View File

@ -42,6 +42,10 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if !isWindowsFirewallReachable() {
return nil
}

View File

@ -1,6 +1,7 @@
package forwarder
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
@ -24,9 +25,12 @@ type Forwarder struct {
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
}
func New(iface common.IFaceMapper) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@ -85,10 +89,13 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(),
ctx: ctx,
cancel: cancel,
}
// Set up TCP forwarder
@ -118,3 +125,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() error {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
return nil
}

View File

@ -1,10 +1,10 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@ -20,9 +20,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
dstPort := id.LocalPort
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
// Dial the destination first
dialer := net.Dialer{}
outConn, err := dialer.Dial("tcp", dialAddr)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
return
@ -40,8 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
// Now that we've successfully connected to the destination,
// we can complete the incoming connection
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
@ -59,24 +56,35 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
}
}()
var wg sync.WaitGroup
wg.Add(2)
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
defer wg.Done()
_, err := io.Copy(outConn, inConn)
if err != nil {
log.Errorf("proxyTCP: copy error: %v", err)
n, err := io.Copy(outConn, inConn)
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
}
errChan <- err
}()
go func() {
defer wg.Done()
_, err := io.Copy(inConn, outConn)
if err != nil {
log.Errorf("proxyTCP: copy error: %v", err)
n, err := io.Copy(inConn, outConn)
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
}
errChan <- err
}()
wg.Wait()
select {
case <-ctx.Done():
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: copy error: %v", err)
}
return
}
}

View File

@ -1,6 +1,8 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"sync"
@ -8,50 +10,95 @@ import (
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
udpTimeout = 60 * time.Second
maxPacketSize = 65535
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastTime time.Time
cancel context.CancelFunc
}
type udpForwarder struct {
sync.RWMutex
conns map[string]*udpPacketConn
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
func newUDPForwarder() *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
conns: make(map[string]*udpPacketConn),
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, maxPacketSize)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
f.Lock()
now := time.Now()
for addr, conn := range f.conns {
for id, conn := range f.conns {
if now.Sub(conn.lastTime) > udpTimeout {
conn.conn.Close()
conn.outConn.Close()
delete(f.conns, addr)
conn.cancel()
if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
}
}
f.Unlock()
}
}
}
// handleUDP is called by the UDP forwarder for new packets
@ -59,12 +106,17 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
id := r.ID()
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil {
log.Debug("forwarder: context done, dropping UDP packet")
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err := r.CreateEndpoint(&wq)
if err != nil {
log.Errorf("Create UDP endpoint error: %v", err)
log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
return
}
@ -72,82 +124,115 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
// Try to get existing connection or create a new one
f.udpForwarder.Lock()
pConn, exists := f.udpForwarder.conns[dstAddr]
defer f.udpForwarder.Unlock()
pConn, exists := f.udpForwarder.conns[id]
if !exists {
outConn, err := net.Dial("udp", dstAddr)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.udpForwarder.Unlock()
if err := inConn.Close(); err != nil {
log.Errorf("forwader: UDP inConn close error: %v", err)
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
}
log.Errorf("forwarder> UDP dial error: %v", err)
log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
return
}
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{
conn: inConn,
outConn: outConn,
lastTime: time.Now(),
cancel: connCancel,
}
f.udpForwarder.conns[dstAddr] = pConn
f.udpForwarder.conns[id] = pConn
go f.proxyUDP(pConn, dstAddr)
go f.proxyUDP(connCtx, pConn, id)
}
f.udpForwarder.Unlock()
}
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err)
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := pConn.outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
}
}()
var wg sync.WaitGroup
wg.Add(2)
// Handle outbound to inbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
}()
// Handle inbound to outbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
}()
wg.Wait()
// Clean up the connection from the map
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, dstAddr)
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
}()
errChan := make(chan error, 2)
go func() {
errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound")
}()
go func() {
errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound")
}()
select {
case <-ctx.Done():
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
}
return
}
}
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
buffer := make([]byte, 65535)
func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error {
bufp := f.udpForwarder.bufPool.Get().(*[]byte)
defer f.udpForwarder.bufPool.Put(bufp)
buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
n, err := src.Read(buffer)
if err != nil {
log.Errorf("UDP %s read error: %v", direction, err)
return
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
log.Errorf("UDP %s write error: %v", direction, err)
continue
return fmt.Errorf("write to %s: %w", direction, err)
}
f.udpForwarder.Lock()
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
if conn, ok := f.udpForwarder.conns[id]; ok {
conn.lastTime = time.Now()
}
f.udpForwarder.Unlock()
}
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}