Add logger

This commit is contained in:
Viktor Liu 2024-12-30 15:18:21 +01:00
parent fad82ee65c
commit d2616544fe
17 changed files with 436 additions and 62 deletions

View File

@ -3,6 +3,11 @@
package uspfilter
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@ -17,23 +22,31 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}

View File

@ -1,9 +1,11 @@
package uspfilter
import (
"context"
"fmt"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@ -29,23 +31,31 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {
return nil
}

View File

@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
// Generate different IPs
@ -89,7 +89,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
// Generate different IPs

View File

@ -6,6 +6,8 @@ import (
"time"
"github.com/google/gopacket/layers"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
@ -33,6 +35,7 @@ type ICMPConnTrack struct {
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
@ -42,12 +45,13 @@ type ICMPTracker struct {
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
@ -83,6 +87,8 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
}
t.mutex.Unlock()
@ -141,6 +147,8 @@ func (t *ICMPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Debug("ICMPTracker: removed connection %v", key)
}
}
}

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")

View File

@ -6,6 +6,8 @@ import (
"net"
"sync"
"time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
@ -67,6 +69,7 @@ type TCPConnTrack struct {
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
@ -76,8 +79,9 @@ type TCPTracker struct {
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
@ -116,6 +120,8 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.mutex.Unlock()
@ -318,6 +324,8 @@ func (t *TCPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Closed TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
}
}

View File

@ -9,7 +9,7 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout)
tracker = NewTCPTracker(DefaultTCPTimeout, nil)
tt.test(t)
})
}
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections

View File

@ -4,6 +4,8 @@ import (
"net"
"sync"
"time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
@ -20,6 +22,7 @@ type UDPConnTrack struct {
// UDPTracker manages UDP connection states
type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
@ -29,12 +32,13 @@ type UDPTracker struct {
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
@ -70,6 +74,8 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
t.logger.Trace("New UDP connection: %s", conn)
}
t.mutex.Unlock()
@ -120,6 +126,8 @@ func (t *UDPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("UDP connection timed out: %s", conn)
}
}
}

View File

@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout)
tracker := NewUDPTracker(tt.timeout, nil)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second)
tracker := NewUDPTracker(1*time.Second, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")

View File

@ -1,15 +1,17 @@
package forwarder
import (
log "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
@ -55,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
// TODO: handle dest ip addresses outside our network
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
log.Errorf("CreateOutboundPacket: %v", err)
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++

View File

@ -14,6 +14,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
@ -22,6 +23,7 @@ const (
)
type Forwarder struct {
logger *nblog.Logger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
@ -29,8 +31,7 @@ type Forwarder struct {
cancel context.CancelFunc
}
func New(iface common.IFaceMapper) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@ -46,6 +47,7 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
@ -91,9 +93,10 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(),
udpForwarder: newUDPForwarder(logger),
ctx: ctx,
cancel: cancel,
}

View File

@ -6,7 +6,6 @@ import (
"io"
"net"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
@ -23,16 +22,19 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
return
}
f.logger.Trace("forwarder: established TCP connection to %v", id)
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err2 := r.CreateEndpoint(&wq)
if err2 != nil {
if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
f.logger.Error("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
@ -49,10 +51,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
defer func() {
if err := inConn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err)
f.logger.Error("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
f.logger.Error("forwarder: outConn close error: %v", err)
}
}()
@ -65,7 +67,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
go func() {
n, err := io.Copy(outConn, inConn)
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
f.logger.Error("inbound->outbound copy error after %d bytes: %v", n, err)
}
errChan <- err
}()
@ -73,7 +75,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
go func() {
n, err := io.Copy(inConn, outConn)
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
f.logger.Error("outbound->inbound copy error after %d bytes: %v", n, err)
}
errChan <- err
}()
@ -83,7 +85,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("proxyTCP: copy error: %v", err)
f.logger.Error("proxyTCP: copy error: %v", err)
}
return
}

View File

@ -8,11 +8,12 @@ import (
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
@ -29,15 +30,17 @@ type udpPacketConn struct {
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
func newUDPForwarder() *udpForwarder {
func newUDPForwarder(logger *nblog.Logger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
@ -62,10 +65,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
}
@ -87,13 +90,13 @@ func (f *udpForwarder) cleanup() {
if now.Sub(conn.lastTime) > udpTimeout {
conn.cancel()
if err := conn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
}
delete(f.conns, id)
log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
}
}
f.Unlock()
@ -107,7 +110,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil {
log.Debug("forwarder: context done, dropping UDP packet")
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
@ -116,7 +119,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
ep, err := r.CreateEndpoint(&wq)
if err != nil {
log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
return
}
@ -131,12 +134,16 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
if err := inConn.Close(); err != nil {
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
}
log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
f.logger.Trace("forwarder: established UDP connection to %v", id)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{
conn: inConn,
@ -154,10 +161,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := pConn.outConn.Close(); err != nil {
log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
}
f.udpForwarder.Lock()
@ -180,7 +187,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
f.logger.Error("proxyUDP: copy error: %v", err)
}
return
}

View File

@ -0,0 +1,208 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
)
// Level represents log severity
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
func FromLogrusLevel(level log.Level) Level {
switch level {
case log.TraceLevel:
return LevelTrace
case log.DebugLevel:
return LevelDebug
case log.InfoLevel:
return LevelInfo
case log.WarnLevel:
return LevelWarn
case log.ErrorLevel:
return LevelError
case log.FatalLevel:
return LevelFatal
case log.PanicLevel:
return LevelPanic
default:
return LevelInfo
}
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
shutdown chan struct{}
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
l.level.Store(uint32(LevelInfo))
l.wg.Add(1)
go l.worker()
return l
}
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000000-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Trace(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() <= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
l.output.Write(buf[:n])
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
close(l.shutdown)
done := make(chan struct{})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@ -0,0 +1,93 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@ -52,6 +53,7 @@ type Manager struct {
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
logger *nblog.Logger
}
// decoder for packages
@ -106,15 +108,17 @@ func create(iface common.IFaceMapper) (*Manager, error) {
stateful: !disableConntrack,
// TODO: fix
routingEnabled: true,
// TODO: support chaning log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
intf := iface.GetWGDevice()
@ -125,7 +129,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
m.routingEnabled = false
} else {
var err error
m.forwarder, err = forwarder.New(iface)
m.forwarder, err = forwarder.New(iface, m.logger)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false
@ -455,17 +459,16 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
log.Debugf("invalid packet: %v", d.decoded)
m.logger.Trace("Invalid packet structure")
return true
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0])
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true
}
// Check if this is local or routed traffic
isLocal := m.isLocalIP(dstIP)
// For all inbound traffic, first check if it matches a tracked connection.
@ -476,7 +479,12 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Handle local traffic - apply peer ACLs
if isLocal {
return m.applyRules(srcIP, packetData, rules, d)
drop := m.applyRules(srcIP, packetData, rules, d)
if drop {
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
srcIP, dstIP)
}
return drop
}
// Handle routed traffic
@ -484,6 +492,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// We might need to apply NAT
// Don't handle routing if not enabled
if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
@ -493,13 +503,15 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Check route ACLs
if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
return true
}
// Let forwarder handle the packet if it passed route ACLs
err := m.forwarder.InjectIncomingPacket(packetData)
if err != nil {
log.Errorf("Failed to inject incoming packet: %v", err)
m.logger.Error("Failed to inject incoming packet: %v", err)
}
// Default: drop

View File

@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{