From d2616544fe3d2532bf374d4a20c2d9e585fdd87a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 30 Dec 2024 15:18:21 +0100 Subject: [PATCH] Add logger --- client/firewall/uspfilter/allow_netbird.go | 19 +- .../uspfilter/allow_netbird_windows.go | 16 +- .../uspfilter/conntrack/common_test.go | 4 +- client/firewall/uspfilter/conntrack/icmp.go | 10 +- .../firewall/uspfilter/conntrack/icmp_test.go | 4 +- client/firewall/uspfilter/conntrack/tcp.go | 10 +- .../firewall/uspfilter/conntrack/tcp_test.go | 14 +- client/firewall/uspfilter/conntrack/udp.go | 10 +- .../firewall/uspfilter/conntrack/udp_test.go | 10 +- .../firewall/uspfilter/forwarder/endpoint.go | 6 +- .../firewall/uspfilter/forwarder/forwarder.go | 9 +- client/firewall/uspfilter/forwarder/tcp.go | 16 +- client/firewall/uspfilter/forwarder/udp.go | 35 +-- client/firewall/uspfilter/log/log.go | 208 ++++++++++++++++++ client/firewall/uspfilter/log/ringbuffer.go | 93 ++++++++ client/firewall/uspfilter/uspfilter.go | 30 ++- client/firewall/uspfilter/uspfilter_test.go | 4 +- 17 files changed, 436 insertions(+), 62 deletions(-) create mode 100644 client/firewall/uspfilter/log/log.go create mode 100644 client/firewall/uspfilter/log/ringbuffer.go diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 297095090..03f23f5e6 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,11 @@ package uspfilter import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -17,23 +22,31 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } if m.forwarder != nil { m.forwarder.Stop() } + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 42bf0896e..379585978 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -1,9 +1,11 @@ package uspfilter import ( + "context" "fmt" "os/exec" "syscall" + "time" log "github.com/sirupsen/logrus" @@ -29,23 +31,31 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } if m.forwarder != nil { m.forwarder.Stop() } + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 72d006def..b885470a3 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) { // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() // Generate different IPs @@ -89,7 +89,7 @@ func BenchmarkMemoryPressure(b *testing.B) { }) b.Run("UDPHighLoad", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() // Generate different IPs diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index e0a971678..277a4b26e 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -6,6 +6,8 @@ import ( "time" "github.com/google/gopacket/layers" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -33,6 +35,7 @@ type ICMPConnTrack struct { // ICMPTracker manages ICMP connection states type ICMPTracker struct { + logger *nblog.Logger connections map[ICMPConnKey]*ICMPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -42,12 +45,13 @@ type ICMPTracker struct { } // NewICMPTracker creates a new ICMP connection tracker -func NewICMPTracker(timeout time.Duration) *ICMPTracker { +func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { if timeout == 0 { timeout = DefaultICMPTimeout } tracker := &ICMPTracker{ + logger: logger, connections: make(map[ICMPConnKey]*ICMPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), @@ -83,6 +87,8 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u conn.lastSeen.Store(now) conn.established.Store(true) t.connections[key] = conn + + t.logger.Trace("New ICMP connection %v", key) } t.mutex.Unlock() @@ -141,6 +147,8 @@ func (t *ICMPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Debug("ICMPTracker: removed connection %v", key) } } } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 21176e719..e653416f9 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a7968dc73..a42208b61 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -6,6 +6,8 @@ import ( "net" "sync" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -67,6 +69,7 @@ type TCPConnTrack struct { // TCPTracker manages TCP connection states type TCPTracker struct { + logger *nblog.Logger connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker @@ -76,8 +79,9 @@ type TCPTracker struct { } // NewTCPTracker creates a new TCP connection tracker -func NewTCPTracker(timeout time.Duration) *TCPTracker { +func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { tracker := &TCPTracker{ + logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), @@ -116,6 +120,8 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.lastSeen.Store(now) conn.established.Store(false) t.connections[key] = conn + + t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) } t.mutex.Unlock() @@ -318,6 +324,8 @@ func (t *TCPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("Closed TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 6c8f82423..c44e7dfa7 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout) + tracker = NewTCPTracker(DefaultTCPTimeout, nil) tt.test(t) }) } @@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a969a4e84..630006349 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -20,6 +22,7 @@ type UDPConnTrack struct { // UDPTracker manages UDP connection states type UDPTracker struct { + logger *nblog.Logger connections map[ConnKey]*UDPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -29,12 +32,13 @@ type UDPTracker struct { } // NewUDPTracker creates a new UDP connection tracker -func NewUDPTracker(timeout time.Duration) *UDPTracker { +func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { if timeout == 0 { timeout = DefaultUDPTimeout } tracker := &UDPTracker{ + logger: logger, connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), @@ -70,6 +74,8 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.lastSeen.Store(now) conn.established.Store(true) t.connections[key] = conn + + t.logger.Trace("New UDP connection: %s", conn) } t.mutex.Unlock() @@ -120,6 +126,8 @@ func (t *UDPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("UDP connection timed out: %s", conn) } } } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 671721890..4e42c484f 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout) + tracker := NewUDPTracker(tt.timeout, nil) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1 * time.Second) + tracker := NewUDPTracker(1*time.Second, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 9f22fe3a2..c234ca241 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -1,15 +1,17 @@ package forwarder import ( - log "github.com/sirupsen/logrus" wgdevice "golang.zx2c4.com/wireguard/device" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device type endpoint struct { + logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device mtu uint32 @@ -55,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) // TODO: handle dest ip addresses outside our network err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) if err != nil { - log.Errorf("CreateOutboundPacket: %v", err) + e.logger.Error("CreateOutboundPacket: %v", err) continue } written++ diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 815c7da09..f39200658 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -14,6 +14,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -22,6 +23,7 @@ const ( ) type Forwarder struct { + logger *nblog.Logger stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -29,8 +31,7 @@ type Forwarder struct { cancel context.CancelFunc } -func New(iface common.IFaceMapper) (*Forwarder, error) { - +func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -46,6 +47,7 @@ func New(iface common.IFaceMapper) (*Forwarder, error) { } nicID := tcpip.NICID(1) endpoint := &endpoint{ + logger: logger, device: iface.GetWGDevice(), mtu: uint32(mtu), } @@ -91,9 +93,10 @@ func New(iface common.IFaceMapper) (*Forwarder, error) { ctx, cancel := context.WithCancel(context.Background()) f := &Forwarder{ + logger: logger, stack: s, endpoint: endpoint, - udpForwarder: newUDPForwarder(), + udpForwarder: newUDPForwarder(logger), ctx: ctx, cancel: cancel, } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 90967b6f5..25503cb6d 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,7 +6,6 @@ import ( "io" "net" - log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" @@ -23,16 +22,19 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) + f.logger.Trace("forwarder: dial error for %v: %v", id, err) return } + f.logger.Trace("forwarder: established TCP connection to %v", id) + // Create wait queue for blocking syscalls wq := waiter.Queue{} ep, err2 := r.CreateEndpoint(&wq) if err2 != nil { if err := outConn.Close(); err != nil { - log.Errorf("forwarder: outConn close error: %v", err) + f.logger.Error("forwarder: outConn close error: %v", err) } r.Complete(true) return @@ -49,10 +51,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { defer func() { if err := inConn.Close(); err != nil { - log.Errorf("forwarder: inConn close error: %v", err) + f.logger.Error("forwarder: inConn close error: %v", err) } if err := outConn.Close(); err != nil { - log.Errorf("forwarder: outConn close error: %v", err) + f.logger.Error("forwarder: outConn close error: %v", err) } }() @@ -65,7 +67,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { go func() { n, err := io.Copy(outConn, inConn) if err != nil && !isClosedError(err) { - log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err) + f.logger.Error("inbound->outbound copy error after %d bytes: %v", n, err) } errChan <- err }() @@ -73,7 +75,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { go func() { n, err := io.Copy(inConn, outConn) if err != nil && !isClosedError(err) { - log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err) + f.logger.Error("outbound->inbound copy error after %d bytes: %v", n, err) } errChan <- err }() @@ -83,7 +85,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { return case err := <-errChan: if err != nil && !isClosedError(err) { - log.Errorf("proxyTCP: copy error: %v", err) + f.logger.Error("proxyTCP: copy error: %v", err) } return } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 7d201024e..bb43a8346 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -8,11 +8,12 @@ import ( "sync" "time" - log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -29,15 +30,17 @@ type udpPacketConn struct { type udpForwarder struct { sync.RWMutex + logger *nblog.Logger conns map[stack.TransportEndpointID]*udpPacketConn bufPool sync.Pool ctx context.Context cancel context.CancelFunc } -func newUDPForwarder() *udpForwarder { +func newUDPForwarder(logger *nblog.Logger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ + logger: logger, conns: make(map[stack.TransportEndpointID]*udpPacketConn), ctx: ctx, cancel: cancel, @@ -62,10 +65,10 @@ func (f *udpForwarder) Stop() { for id, conn := range f.conns { conn.cancel() if err := conn.conn.Close(); err != nil { - log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err) } if err := conn.outConn.Close(); err != nil { - log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err) } delete(f.conns, id) } @@ -87,13 +90,13 @@ func (f *udpForwarder) cleanup() { if now.Sub(conn.lastTime) > udpTimeout { conn.cancel() if err := conn.conn.Close(); err != nil { - log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err) } if err := conn.outConn.Close(); err != nil { - log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err) } delete(f.conns, id) - log.Debugf("forwarder: cleaned up idle UDP connection %v", id) + f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id) } } f.Unlock() @@ -107,7 +110,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) if f.ctx.Err() != nil { - log.Debug("forwarder: context done, dropping UDP packet") + f.logger.Trace("forwarder: context done, dropping UDP packet") return } @@ -116,7 +119,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { ep, err := r.CreateEndpoint(&wq) if err != nil { - log.Errorf("forwarder: failed to create UDP endpoint: %v", err) + f.logger.Error("forwarder: failed to create UDP endpoint: %v", err) return } @@ -131,12 +134,16 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { if err := inConn.Close(); err != nil { - log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err) } - log.Errorf("forwarder: UDP dial error for %v: %v", id, err) + f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) + + // TODO: Send ICMP error message return } + f.logger.Trace("forwarder: established UDP connection to %v", id) + connCtx, connCancel := context.WithCancel(f.ctx) pConn = &udpPacketConn{ conn: inConn, @@ -154,10 +161,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack defer func() { pConn.cancel() if err := pConn.conn.Close(); err != nil { - log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err) } if err := pConn.outConn.Close(); err != nil { - log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) + f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err) } f.udpForwarder.Lock() @@ -180,7 +187,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack return case err := <-errChan: if err != nil && !isClosedError(err) { - log.Errorf("forwader: UDP proxy error for %v: %v", id, err) + f.logger.Error("proxyUDP: copy error: %v", err) } return } diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go new file mode 100644 index 000000000..2e9a4d4b7 --- /dev/null +++ b/client/firewall/uspfilter/log/log.go @@ -0,0 +1,208 @@ +// Package logger provides a high-performance, non-blocking logger for userspace networking +package log + +import ( + "context" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxBatchSize = 1024 * 16 // 16KB max batch size + maxMessageSize = 1024 * 2 // 2KB per message + bufferSize = 1024 * 256 // 256KB ring buffer + defaultFlushInterval = 2 * time.Second +) + +// Level represents log severity +type Level uint32 + +const ( + LevelPanic Level = iota + LevelFatal + LevelError + LevelWarn + LevelInfo + LevelDebug + LevelTrace +) + +var levelStrings = map[Level]string{ + LevelPanic: "PANC", + LevelFatal: "FATL", + LevelError: "ERRO", + LevelWarn: "WARN", + LevelInfo: "INFO", + LevelDebug: "DEBG", + LevelTrace: "TRAC", +} + +func FromLogrusLevel(level log.Level) Level { + switch level { + case log.TraceLevel: + return LevelTrace + case log.DebugLevel: + return LevelDebug + case log.InfoLevel: + return LevelInfo + case log.WarnLevel: + return LevelWarn + case log.ErrorLevel: + return LevelError + case log.FatalLevel: + return LevelFatal + case log.PanicLevel: + return LevelPanic + default: + return LevelInfo + } +} + +// Logger is a high-performance, non-blocking logger +type Logger struct { + output io.Writer + level atomic.Uint32 + buffer *ringBuffer + shutdown chan struct{} + wg sync.WaitGroup + + // Reusable buffer pool for formatting messages + bufPool sync.Pool +} + +func NewFromLogrus(logrusLogger *log.Logger) *Logger { + l := &Logger{ + output: logrusLogger.Out, + buffer: newRingBuffer(bufferSize), + shutdown: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + // Pre-allocate buffer for message formatting + b := make([]byte, 0, maxMessageSize) + return &b + }, + }, + } + l.level.Store(uint32(LevelInfo)) + + l.wg.Add(1) + go l.worker() + + return l +} + +func (l *Logger) SetLevel(level Level) { + l.level.Store(uint32(level)) +} + +func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { + *buf = (*buf)[:0] + + // Timestamp + *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000000-07:00") + *buf = append(*buf, ' ') + + // Level + *buf = append(*buf, levelStrings[level]...) + *buf = append(*buf, ' ') + + // Message + if len(args) > 0 { + *buf = append(*buf, fmt.Sprintf(format, args...)...) + } else { + *buf = append(*buf, format...) + } + + *buf = append(*buf, '\n') +} + +func (l *Logger) log(level Level, format string, args ...interface{}) { + bufp := l.bufPool.Get().(*[]byte) + l.formatMessage(bufp, level, format, args...) + + if len(*bufp) > maxMessageSize { + *bufp = (*bufp)[:maxMessageSize] + } + l.buffer.Write(*bufp) + + l.bufPool.Put(bufp) +} + +func (l *Logger) Trace(format string, args ...interface{}) { + if l.level.Load() <= uint32(LevelTrace) { + l.log(LevelTrace, format, args...) + } +} + +func (l *Logger) Debug(format string, args ...interface{}) { + if l.level.Load() <= uint32(LevelDebug) { + l.log(LevelDebug, format, args...) + } +} + +func (l *Logger) Info(format string, args ...interface{}) { + if l.level.Load() <= uint32(LevelInfo) { + l.log(LevelInfo, format, args...) + } +} + +func (l *Logger) Warn(format string, args ...interface{}) { + if l.level.Load() <= uint32(LevelWarn) { + l.log(LevelWarn, format, args...) + } +} + +func (l *Logger) Error(format string, args ...interface{}) { + if l.level.Load() <= uint32(LevelError) { + l.log(LevelError, format, args...) + } +} + +// worker periodically flushes the buffer +func (l *Logger) worker() { + defer l.wg.Done() + + ticker := time.NewTicker(defaultFlushInterval) + defer ticker.Stop() + + buf := make([]byte, 0, maxBatchSize) + + for { + select { + case <-l.shutdown: + return + case <-ticker.C: + // Read accumulated messages + n, _ := l.buffer.Read(buf[:cap(buf)]) + if n == 0 { + continue + } + + // Write batch + l.output.Write(buf[:n]) + } + } +} + +// Stop gracefully shuts down the logger +func (l *Logger) Stop(ctx context.Context) error { + close(l.shutdown) + + done := make(chan struct{}) + go func() { + l.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} diff --git a/client/firewall/uspfilter/log/ringbuffer.go b/client/firewall/uspfilter/log/ringbuffer.go new file mode 100644 index 000000000..48ebe84ae --- /dev/null +++ b/client/firewall/uspfilter/log/ringbuffer.go @@ -0,0 +1,93 @@ +package log + +import "sync" + +// ringBuffer is a simple ring buffer implementation +type ringBuffer struct { + buf []byte + size int + r, w int64 // Read and write positions + mu sync.Mutex +} + +func newRingBuffer(size int) *ringBuffer { + return &ringBuffer{ + buf: make([]byte, size), + size: size, + } +} + +func (r *ringBuffer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + if len(p) > r.size { + p = p[:r.size] + } + + n = len(p) + + // Write data, handling wrap-around + pos := int(r.w % int64(r.size)) + writeLen := min(len(p), r.size-pos) + copy(r.buf[pos:], p[:writeLen]) + + // If we have more data and need to wrap around + if writeLen < len(p) { + copy(r.buf, p[writeLen:]) + } + + // Update write position + r.w += int64(n) + + return n, nil +} + +func (r *ringBuffer) Read(p []byte) (n int, err error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.w == r.r { + return 0, nil + } + + // Calculate available data accounting for wraparound + available := int(r.w - r.r) + if available < 0 { + available += r.size + } + available = min(available, r.size) + + // Limit read to buffer size + toRead := min(available, len(p)) + if toRead == 0 { + return 0, nil + } + + // Read data, handling wrap-around + pos := int(r.r % int64(r.size)) + readLen := min(toRead, r.size-pos) + n = copy(p, r.buf[pos:pos+readLen]) + + // If we need more data and need to wrap around + if readLen < toRead { + n += copy(p[readLen:toRead], r.buf[:toRead-readLen]) + } + + // Update read position + r.r += int64(n) + + return n, nil +} + +// min returns the smaller of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index feed1887b..55e2063ec 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -52,6 +53,7 @@ type Manager struct { icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker forwarder *forwarder.Forwarder + logger *nblog.Logger } // decoder for packages @@ -106,15 +108,17 @@ func create(iface common.IFaceMapper) (*Manager, error) { stateful: !disableConntrack, // TODO: fix routingEnabled: true, + // TODO: support chaning log level from logrus + logger: nblog.NewFromLogrus(log.StandardLogger()), } // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } intf := iface.GetWGDevice() @@ -125,7 +129,7 @@ func create(iface common.IFaceMapper) (*Manager, error) { m.routingEnabled = false } else { var err error - m.forwarder, err = forwarder.New(iface) + m.forwarder, err = forwarder.New(iface, m.logger) if err != nil { log.Errorf("failed to create forwarder: %v", err) m.routingEnabled = false @@ -455,17 +459,16 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { defer m.decoders.Put(d) if !m.isValidPacket(d, packetData) { - log.Debugf("invalid packet: %v", d.decoded) + m.logger.Trace("Invalid packet structure") return true } srcIP, dstIP := m.extractIPs(d) if srcIP == nil { - log.Errorf("unknown layer: %v", d.decoded[0]) + m.logger.Error("Unknown network layer: %v", d.decoded[0]) return true } - // Check if this is local or routed traffic isLocal := m.isLocalIP(dstIP) // For all inbound traffic, first check if it matches a tracked connection. @@ -476,7 +479,12 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // Handle local traffic - apply peer ACLs if isLocal { - return m.applyRules(srcIP, packetData, rules, d) + drop := m.applyRules(srcIP, packetData, rules, d) + if drop { + m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied", + srcIP, dstIP) + } + return drop } // Handle routed traffic @@ -484,6 +492,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // We might need to apply NAT // Don't handle routing if not enabled if !m.routingEnabled { + m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) return true } @@ -493,13 +503,15 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // Check route ACLs if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) { + m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", + srcIP, srcPort, dstIP, dstPort, proto) return true } // Let forwarder handle the packet if it passed route ACLs err := m.forwarder.InjectIncomingPacket(packetData) if err != nil { - log.Errorf("Failed to inject incoming packet: %v", err) + m.logger.Error("Failed to inject incoming packet: %v", err) } // Default: drop diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 443d82607..2d85116d4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil) manager.decoders = sync.Pool{ New: func() any { d := &decoder{