Add logger

This commit is contained in:
Viktor Liu 2024-12-30 15:18:21 +01:00
parent fad82ee65c
commit d2616544fe
17 changed files with 436 additions and 62 deletions

View File

@ -3,6 +3,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -17,23 +22,31 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
if m.forwarder != nil { if m.forwarder != nil {
m.forwarder.Stop() m.forwarder.Stop()
} }
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Reset(stateManager)
} }

View File

@ -1,9 +1,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -29,23 +31,31 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
if m.forwarder != nil { if m.forwarder != nil {
m.forwarder.Stop() m.forwarder.Stop()
} }
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {
return nil return nil
} }

View File

@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
@ -89,7 +89,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs

View File

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
type ICMPTracker struct { type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -42,12 +45,13 @@ type ICMPTracker struct {
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
@ -83,6 +87,8 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
conn.lastSeen.Store(now) conn.lastSeen.Store(now)
conn.established.Store(true) conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
} }
t.mutex.Unlock() t.mutex.Unlock()
@ -141,6 +147,8 @@ func (t *ICMPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("ICMPTracker: removed connection %v", key)
} }
} }
} }

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -6,6 +6,8 @@ import (
"net" "net"
"sync" "sync"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -67,6 +69,7 @@ type TCPConnTrack struct {
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -76,8 +79,9 @@ type TCPTracker struct {
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
@ -116,6 +120,8 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now) conn.lastSeen.Store(now)
conn.established.Store(false) conn.established.Store(false)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
} }
t.mutex.Unlock() t.mutex.Unlock()
@ -318,6 +324,8 @@ func (t *TCPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Closed TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
} }

View File

@ -9,7 +9,7 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout) tracker = NewTCPTracker(DefaultTCPTimeout, nil)
tt.test(t) tt.test(t)
}) })
} }
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections

View File

@ -4,6 +4,8 @@ import (
"net" "net"
"sync" "sync"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -20,6 +22,7 @@ type UDPConnTrack struct {
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
type UDPTracker struct { type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@ -29,12 +32,13 @@ type UDPTracker struct {
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
@ -70,6 +74,8 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now) conn.lastSeen.Store(now)
conn.established.Store(true) conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New UDP connection: %s", conn)
} }
t.mutex.Unlock() t.mutex.Unlock()
@ -120,6 +126,8 @@ func (t *UDPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("UDP connection timed out: %s", conn)
} }
} }
} }

View File

@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout) tracker := NewUDPTracker(tt.timeout, nil)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second) tracker := NewUDPTracker(1*time.Second, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -1,15 +1,17 @@
package forwarder package forwarder
import ( import (
log "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct { type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
device *wgdevice.Device device *wgdevice.Device
mtu uint32 mtu uint32
@ -55,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
// TODO: handle dest ip addresses outside our network // TODO: handle dest ip addresses outside our network
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil { if err != nil {
log.Errorf("CreateOutboundPacket: %v", err) e.logger.Error("CreateOutboundPacket: %v", err)
continue continue
} }
written++ written++

View File

@ -14,6 +14,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -22,6 +23,7 @@ const (
) )
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
@ -29,8 +31,7 @@ type Forwarder struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func New(iface common.IFaceMapper) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@ -46,6 +47,7 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
} }
nicID := tcpip.NICID(1) nicID := tcpip.NICID(1)
endpoint := &endpoint{ endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(), device: iface.GetWGDevice(),
mtu: uint32(mtu), mtu: uint32(mtu),
} }
@ -91,9 +93,10 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
logger: logger,
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(), udpForwarder: newUDPForwarder(logger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }

View File

@ -6,7 +6,6 @@ import (
"io" "io"
"net" "net"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
@ -23,16 +22,19 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
return return
} }
f.logger.Trace("forwarder: established TCP connection to %v", id)
// Create wait queue for blocking syscalls // Create wait queue for blocking syscalls
wq := waiter.Queue{} wq := waiter.Queue{}
ep, err2 := r.CreateEndpoint(&wq) ep, err2 := r.CreateEndpoint(&wq)
if err2 != nil { if err2 != nil {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err) f.logger.Error("forwarder: outConn close error: %v", err)
} }
r.Complete(true) r.Complete(true)
return return
@ -49,10 +51,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
defer func() { defer func() {
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err) f.logger.Error("forwarder: inConn close error: %v", err)
} }
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err) f.logger.Error("forwarder: outConn close error: %v", err)
} }
}() }()
@ -65,7 +67,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
go func() { go func() {
n, err := io.Copy(outConn, inConn) n, err := io.Copy(outConn, inConn)
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err) f.logger.Error("inbound->outbound copy error after %d bytes: %v", n, err)
} }
errChan <- err errChan <- err
}() }()
@ -73,7 +75,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
go func() { go func() {
n, err := io.Copy(inConn, outConn) n, err := io.Copy(inConn, outConn)
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err) f.logger.Error("outbound->inbound copy error after %d bytes: %v", n, err)
} }
errChan <- err errChan <- err
}() }()
@ -83,7 +85,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: copy error: %v", err) f.logger.Error("proxyTCP: copy error: %v", err)
} }
return return
} }

View File

@ -8,11 +8,12 @@ import (
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@ -29,15 +30,17 @@ type udpPacketConn struct {
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool bufPool sync.Pool
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
func newUDPForwarder() *udpForwarder { func newUDPForwarder(logger *nblog.Logger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn), conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -62,10 +65,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns { for id, conn := range f.conns {
conn.cancel() conn.cancel()
if err := conn.conn.Close(); err != nil { if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
} }
if err := conn.outConn.Close(); err != nil { if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
} }
delete(f.conns, id) delete(f.conns, id)
} }
@ -87,13 +90,13 @@ func (f *udpForwarder) cleanup() {
if now.Sub(conn.lastTime) > udpTimeout { if now.Sub(conn.lastTime) > udpTimeout {
conn.cancel() conn.cancel()
if err := conn.conn.Close(); err != nil { if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
} }
if err := conn.outConn.Close(); err != nil { if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
} }
delete(f.conns, id) delete(f.conns, id)
log.Debugf("forwarder: cleaned up idle UDP connection %v", id) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
} }
} }
f.Unlock() f.Unlock()
@ -107,7 +110,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil { if f.ctx.Err() != nil {
log.Debug("forwarder: context done, dropping UDP packet") f.logger.Trace("forwarder: context done, dropping UDP packet")
return return
} }
@ -116,7 +119,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
log.Errorf("forwarder: failed to create UDP endpoint: %v", err) f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
return return
} }
@ -131,12 +134,16 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
} }
log.Errorf("forwarder: UDP dial error for %v: %v", id, err) f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return return
} }
f.logger.Trace("forwarder: established UDP connection to %v", id)
connCtx, connCancel := context.WithCancel(f.ctx) connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{ pConn = &udpPacketConn{
conn: inConn, conn: inConn,
@ -154,10 +161,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
defer func() { defer func() {
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
} }
f.udpForwarder.Lock() f.udpForwarder.Lock()
@ -180,7 +187,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
log.Errorf("forwader: UDP proxy error for %v: %v", id, err) f.logger.Error("proxyUDP: copy error: %v", err)
} }
return return
} }

View File

@ -0,0 +1,208 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
)
// Level represents log severity
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
func FromLogrusLevel(level log.Level) Level {
switch level {
case log.TraceLevel:
return LevelTrace
case log.DebugLevel:
return LevelDebug
case log.InfoLevel:
return LevelInfo
case log.WarnLevel:
return LevelWarn
case log.ErrorLevel:
return LevelError
case log.FatalLevel:
return LevelFatal
case log.PanicLevel:
return LevelPanic
default:
return LevelInfo
}
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
shutdown chan struct{}
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
l.level.Store(uint32(LevelInfo))
l.wg.Add(1)
go l.worker()
return l
}
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000000-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Trace(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
l.output.Write(buf[:n])
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
close(l.shutdown)
done := make(chan struct{})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@ -0,0 +1,93 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -52,6 +53,7 @@ type Manager struct {
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder forwarder *forwarder.Forwarder
logger *nblog.Logger
} }
// decoder for packages // decoder for packages
@ -106,15 +108,17 @@ func create(iface common.IFaceMapper) (*Manager, error) {
stateful: !disableConntrack, stateful: !disableConntrack,
// TODO: fix // TODO: fix
routingEnabled: true, routingEnabled: true,
// TODO: support chaning log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
} }
// Only initialize trackers if stateful mode is enabled // Only initialize trackers if stateful mode is enabled
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
intf := iface.GetWGDevice() intf := iface.GetWGDevice()
@ -125,7 +129,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
m.routingEnabled = false m.routingEnabled = false
} else { } else {
var err error var err error
m.forwarder, err = forwarder.New(iface) m.forwarder, err = forwarder.New(iface, m.logger)
if err != nil { if err != nil {
log.Errorf("failed to create forwarder: %v", err) log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false m.routingEnabled = false
@ -455,17 +459,16 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
defer m.decoders.Put(d) defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) { if !m.isValidPacket(d, packetData) {
log.Debugf("invalid packet: %v", d.decoded) m.logger.Trace("Invalid packet structure")
return true return true
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
// Check if this is local or routed traffic
isLocal := m.isLocalIP(dstIP) isLocal := m.isLocalIP(dstIP)
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
@ -476,7 +479,12 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Handle local traffic - apply peer ACLs // Handle local traffic - apply peer ACLs
if isLocal { if isLocal {
return m.applyRules(srcIP, packetData, rules, d) drop := m.applyRules(srcIP, packetData, rules, d)
if drop {
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
srcIP, dstIP)
}
return drop
} }
// Handle routed traffic // Handle routed traffic
@ -484,6 +492,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// We might need to apply NAT // We might need to apply NAT
// Don't handle routing if not enabled // Don't handle routing if not enabled
if !m.routingEnabled { if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true return true
} }
@ -493,13 +503,15 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Check route ACLs // Check route ACLs
if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) { if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
return true return true
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
err := m.forwarder.InjectIncomingPacket(packetData) err := m.forwarder.InjectIncomingPacket(packetData)
if err != nil { if err != nil {
log.Errorf("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
// Default: drop // Default: drop

View File

@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Reset(nil))
}() }()
@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{