mirror of https://github.com/netbirdio/netbird.git
Add logger
This commit is contained in:
parent
fad82ee65c
commit
d2616544fe
@@ -3,6 +3,11 @@
 package uspfilter

 import (
+	"context"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )
@@ -17,23 +22,31 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {

 	if m.udpTracker != nil {
 		m.udpTracker.Close()
-		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
 	}

 	if m.icmpTracker != nil {
 		m.icmpTracker.Close()
-		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
 	}

 	if m.tcpTracker != nil {
 		m.tcpTracker.Close()
-		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
 	}

 	if m.forwarder != nil {
 		m.forwarder.Stop()
 	}

+	if m.logger != nil {
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+		if err := m.logger.Stop(ctx); err != nil {
+			log.Errorf("failed to shutdown logger: %v", err)
+		}
+	}
+
 	if m.nativeFirewall != nil {
 		return m.nativeFirewall.Reset(stateManager)
 	}
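Note: Reset now stops the packet logger under a 2-second context so that a slow or stuck writer cannot hang firewall teardown. A minimal standalone sketch of the same pattern, assuming only the nblog package added by this commit:

package main

import (
	"context"
	"time"

	log "github.com/sirupsen/logrus"

	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)

func main() {
	logger := nblog.NewFromLogrus(log.StandardLogger())
	logger.Info("firewall starting")

	// Bound the shutdown: Stop waits for the background flush worker to
	// exit, but gives up once the context expires (2s, matching Reset).
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := logger.Stop(ctx); err != nil {
		log.Errorf("failed to shutdown logger: %v", err)
	}
}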
@@ -1,9 +1,11 @@
 package uspfilter

 import (
+	"context"
 	"fmt"
 	"os/exec"
 	"syscall"
+	"time"

 	log "github.com/sirupsen/logrus"

@@ -29,23 +31,31 @@ func (m *Manager) Reset(*statemanager.Manager) error {

 	if m.udpTracker != nil {
 		m.udpTracker.Close()
-		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
 	}

 	if m.icmpTracker != nil {
 		m.icmpTracker.Close()
-		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
 	}

 	if m.tcpTracker != nil {
 		m.tcpTracker.Close()
-		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
 	}

 	if m.forwarder != nil {
 		m.forwarder.Stop()
 	}

+	if m.logger != nil {
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+		if err := m.logger.Stop(ctx); err != nil {
+			log.Errorf("failed to shutdown logger: %v", err)
+		}
+	}
+
 	if !isWindowsFirewallReachable() {
 		return nil
 	}
@@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
 // Memory pressure tests
 func BenchmarkMemoryPressure(b *testing.B) {
 	b.Run("TCPHighLoad", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout)
+		tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 		defer tracker.Close()

 		// Generate different IPs
@@ -89,7 +89,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
 	})

 	b.Run("UDPHighLoad", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout)
+		tracker := NewUDPTracker(DefaultUDPTimeout, nil)
 		defer tracker.Close()

 		// Generate different IPs
@@ -6,6 +6,8 @@ import (
 	"time"

 	"github.com/google/gopacket/layers"
+
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 const (
@@ -33,6 +35,7 @@ type ICMPConnTrack struct {

 // ICMPTracker manages ICMP connection states
 type ICMPTracker struct {
+	logger        *nblog.Logger
 	connections   map[ICMPConnKey]*ICMPConnTrack
 	timeout       time.Duration
 	cleanupTicker *time.Ticker
@@ -42,12 +45,13 @@ type ICMPTracker struct {
 }

 // NewICMPTracker creates a new ICMP connection tracker
-func NewICMPTracker(timeout time.Duration) *ICMPTracker {
+func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
 	if timeout == 0 {
 		timeout = DefaultICMPTimeout
 	}

 	tracker := &ICMPTracker{
+		logger:        logger,
 		connections:   make(map[ICMPConnKey]*ICMPConnTrack),
 		timeout:       timeout,
 		cleanupTicker: time.NewTicker(ICMPCleanupInterval),
@@ -83,6 +87,8 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
 		conn.lastSeen.Store(now)
 		conn.established.Store(true)
 		t.connections[key] = conn
+
+		t.logger.Trace("New ICMP connection %v", key)
 	}
 	t.mutex.Unlock()

@@ -141,6 +147,8 @@ func (t *ICMPTracker) cleanup() {
 			t.ipPool.Put(conn.SourceIP)
 			t.ipPool.Put(conn.DestIP)
 			delete(t.connections, key)
+
+			t.logger.Debug("ICMPTracker: removed connection %v", key)
 		}
 	}
 }
@@ -7,7 +7,7 @@ import (

 func BenchmarkICMPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewICMPTracker(DefaultICMPTimeout)
+		tracker := NewICMPTracker(DefaultICMPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewICMPTracker(DefaultICMPTimeout)
+		tracker := NewICMPTracker(DefaultICMPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -6,6 +6,8 @@ import (
 	"net"
 	"sync"
 	"time"
+
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 const (
@@ -67,6 +69,7 @@ type TCPConnTrack struct {

 // TCPTracker manages TCP connection states
 type TCPTracker struct {
+	logger        *nblog.Logger
 	connections   map[ConnKey]*TCPConnTrack
 	mutex         sync.RWMutex
 	cleanupTicker *time.Ticker
@@ -76,8 +79,9 @@ type TCPTracker struct {
 }

 // NewTCPTracker creates a new TCP connection tracker
-func NewTCPTracker(timeout time.Duration) *TCPTracker {
+func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
 	tracker := &TCPTracker{
+		logger:        logger,
 		connections:   make(map[ConnKey]*TCPConnTrack),
 		cleanupTicker: time.NewTicker(TCPCleanupInterval),
 		done:          make(chan struct{}),
@@ -116,6 +120,8 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
 		conn.lastSeen.Store(now)
 		conn.established.Store(false)
 		t.connections[key] = conn
+
+		t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
 	}
 	t.mutex.Unlock()

@@ -318,6 +324,8 @@ func (t *TCPTracker) cleanup() {
 			t.ipPool.Put(conn.SourceIP)
 			t.ipPool.Put(conn.DestIP)
 			delete(t.connections, key)
+
+			t.logger.Trace("Closed TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
 		}
 	}
 }
@@ -9,7 +9,7 @@ import (
 )

 func TestTCPStateMachine(t *testing.T) {
-	tracker := NewTCPTracker(DefaultTCPTimeout)
+	tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 	defer tracker.Close()

 	srcIP := net.ParseIP("100.64.0.1")
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Helper()

-			tracker = NewTCPTracker(DefaultTCPTimeout)
+			tracker = NewTCPTracker(DefaultTCPTimeout, nil)
 			tt.test(t)
 		})
 	}
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
 }

 func TestRSTHandling(t *testing.T) {
-	tracker := NewTCPTracker(DefaultTCPTimeout)
+	tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 	defer tracker.Close()

 	srcIP := net.ParseIP("100.64.0.1")
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,

 func BenchmarkTCPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout)
+		tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout)
+		tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 	})

 	b.Run("ConcurrentAccess", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout)
+		tracker := NewTCPTracker(DefaultTCPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 // Benchmark connection cleanup
 func BenchmarkCleanup(b *testing.B) {
 	b.Run("TCPCleanup", func(b *testing.B) {
-		tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
+		tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
 		defer tracker.Close()

 		// Pre-populate with expired connections
@@ -4,6 +4,8 @@ import (
 	"net"
 	"sync"
 	"time"
+
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 const (
@@ -20,6 +22,7 @@ type UDPConnTrack struct {

 // UDPTracker manages UDP connection states
 type UDPTracker struct {
+	logger        *nblog.Logger
 	connections   map[ConnKey]*UDPConnTrack
 	timeout       time.Duration
 	cleanupTicker *time.Ticker
@@ -29,12 +32,13 @@ type UDPTracker struct {
 }

 // NewUDPTracker creates a new UDP connection tracker
-func NewUDPTracker(timeout time.Duration) *UDPTracker {
+func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
 	if timeout == 0 {
 		timeout = DefaultUDPTimeout
 	}

 	tracker := &UDPTracker{
+		logger:        logger,
 		connections:   make(map[ConnKey]*UDPConnTrack),
 		timeout:       timeout,
 		cleanupTicker: time.NewTicker(UDPCleanupInterval),
@@ -70,6 +74,8 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
 		conn.lastSeen.Store(now)
 		conn.established.Store(true)
 		t.connections[key] = conn
+
+		t.logger.Trace("New UDP connection: %s", conn)
 	}
 	t.mutex.Unlock()

@@ -120,6 +126,8 @@ func (t *UDPTracker) cleanup() {
 			t.ipPool.Put(conn.SourceIP)
 			t.ipPool.Put(conn.DestIP)
 			delete(t.connections, key)
+
+			t.logger.Trace("UDP connection timed out: %s", conn)
 		}
 	}
 }
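Note: all three conntrack constructors now take the shared *nblog.Logger as a second parameter (the tests pass nil). A minimal sketch of wiring the trackers to one logger outside the Manager, assuming only the import paths shown in this diff:

package main

import (
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)

func main() {
	logger := nblog.NewFromLogrus(log.StandardLogger())

	// Each tracker reports new and expired connections through the
	// same non-blocking logger instead of calling logrus per packet.
	udp := conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, logger)
	defer udp.Close()
	icmp := conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, logger)
	defer icmp.Close()
	tcp := conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, logger)
	defer tcp.Close()
}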
@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			tracker := NewUDPTracker(tt.timeout)
+			tracker := NewUDPTracker(tt.timeout, nil)
 			assert.NotNil(t, tracker)
 			assert.Equal(t, tt.wantTimeout, tracker.timeout)
 			assert.NotNil(t, tracker.connections)
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
 }

 func TestUDPTracker_TrackOutbound(t *testing.T) {
-	tracker := NewUDPTracker(DefaultUDPTimeout)
+	tracker := NewUDPTracker(DefaultUDPTimeout, nil)
 	defer tracker.Close()

 	srcIP := net.ParseIP("192.168.1.2")
@@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
 }

 func TestUDPTracker_IsValidInbound(t *testing.T) {
-	tracker := NewUDPTracker(1 * time.Second)
+	tracker := NewUDPTracker(1*time.Second, nil)
 	defer tracker.Close()

 	srcIP := net.ParseIP("192.168.1.2")
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {

 func BenchmarkUDPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout)
+		tracker := NewUDPTracker(DefaultUDPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout)
+		tracker := NewUDPTracker(DefaultUDPTimeout, nil)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -1,15 +1,17 @@
 package forwarder

 import (
-	log "github.com/sirupsen/logrus"
 	wgdevice "golang.zx2c4.com/wireguard/device"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
+
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
 type endpoint struct {
+	logger     *nblog.Logger
 	dispatcher stack.NetworkDispatcher
 	device     *wgdevice.Device
 	mtu        uint32
@@ -55,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
 		// TODO: handle dest ip addresses outside our network
 		err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
 		if err != nil {
-			log.Errorf("CreateOutboundPacket: %v", err)
+			e.logger.Error("CreateOutboundPacket: %v", err)
 			continue
 		}
 		written++
@@ -14,6 +14,7 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"

 	"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 const (
@@ -22,6 +23,7 @@ const (
 )

 type Forwarder struct {
+	logger       *nblog.Logger
 	stack        *stack.Stack
 	endpoint     *endpoint
 	udpForwarder *udpForwarder
@@ -29,8 +31,7 @@ type Forwarder struct {
 	cancel       context.CancelFunc
 }

-func New(iface common.IFaceMapper) (*Forwarder, error) {
-
+func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
 		TransportProtocols: []stack.TransportProtocolFactory{
@@ -46,6 +47,7 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {
 	}
 	nicID := tcpip.NICID(1)
 	endpoint := &endpoint{
+		logger: logger,
 		device: iface.GetWGDevice(),
 		mtu:    uint32(mtu),
 	}
@@ -91,9 +93,10 @@ func New(iface common.IFaceMapper) (*Forwarder, error) {

 	ctx, cancel := context.WithCancel(context.Background())
 	f := &Forwarder{
+		logger:       logger,
 		stack:        s,
 		endpoint:     endpoint,
-		udpForwarder: newUDPForwarder(),
+		udpForwarder: newUDPForwarder(logger),
 		ctx:          ctx,
 		cancel:       cancel,
 	}
@@ -6,7 +6,6 @@ import (
 	"io"
 	"net"

-	log "github.com/sirupsen/logrus"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
 	"gvisor.dev/gvisor/pkg/waiter"
@@ -23,16 +22,19 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
 	outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
 	if err != nil {
 		r.Complete(true)
+		f.logger.Trace("forwarder: dial error for %v: %v", id, err)
 		return
 	}

+	f.logger.Trace("forwarder: established TCP connection to %v", id)
+
 	// Create wait queue for blocking syscalls
 	wq := waiter.Queue{}

 	ep, err2 := r.CreateEndpoint(&wq)
 	if err2 != nil {
 		if err := outConn.Close(); err != nil {
-			log.Errorf("forwarder: outConn close error: %v", err)
+			f.logger.Error("forwarder: outConn close error: %v", err)
 		}
 		r.Complete(true)
 		return
@@ -49,10 +51,10 @@
 func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
 	defer func() {
 		if err := inConn.Close(); err != nil {
-			log.Errorf("forwarder: inConn close error: %v", err)
+			f.logger.Error("forwarder: inConn close error: %v", err)
 		}
 		if err := outConn.Close(); err != nil {
-			log.Errorf("forwarder: outConn close error: %v", err)
+			f.logger.Error("forwarder: outConn close error: %v", err)
 		}
 	}()

@@ -65,7 +67,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
 	go func() {
 		n, err := io.Copy(outConn, inConn)
 		if err != nil && !isClosedError(err) {
-			log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err)
+			f.logger.Error("inbound->outbound copy error after %d bytes: %v", n, err)
 		}
 		errChan <- err
 	}()
@@ -73,7 +75,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
 	go func() {
 		n, err := io.Copy(inConn, outConn)
 		if err != nil && !isClosedError(err) {
-			log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err)
+			f.logger.Error("outbound->inbound copy error after %d bytes: %v", n, err)
 		}
 		errChan <- err
 	}()
@@ -83,7 +85,7 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
 		return
 	case err := <-errChan:
 		if err != nil && !isClosedError(err) {
-			log.Errorf("proxyTCP: copy error: %v", err)
+			f.logger.Error("proxyTCP: copy error: %v", err)
 		}
 		return
 	}
@@ -8,11 +8,12 @@ import (
 	"sync"
 	"time"

-	log "github.com/sirupsen/logrus"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 	"gvisor.dev/gvisor/pkg/waiter"
+
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 const (
@@ -29,15 +30,17 @@ type udpPacketConn struct {

 type udpForwarder struct {
 	sync.RWMutex
+	logger  *nblog.Logger
 	conns   map[stack.TransportEndpointID]*udpPacketConn
 	bufPool sync.Pool
 	ctx     context.Context
 	cancel  context.CancelFunc
 }

-func newUDPForwarder() *udpForwarder {
+func newUDPForwarder(logger *nblog.Logger) *udpForwarder {
 	ctx, cancel := context.WithCancel(context.Background())
 	f := &udpForwarder{
+		logger: logger,
 		conns:  make(map[stack.TransportEndpointID]*udpPacketConn),
 		ctx:    ctx,
 		cancel: cancel,
@@ -62,10 +65,10 @@ func (f *udpForwarder) Stop() {
 	for id, conn := range f.conns {
 		conn.cancel()
 		if err := conn.conn.Close(); err != nil {
-			log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
+			f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
 		}
 		if err := conn.outConn.Close(); err != nil {
-			log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
+			f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
 		}
 		delete(f.conns, id)
 	}
@@ -87,13 +90,13 @@ func (f *udpForwarder) cleanup() {
 			if now.Sub(conn.lastTime) > udpTimeout {
 				conn.cancel()
 				if err := conn.conn.Close(); err != nil {
-					log.Errorf("forwarder: UDP conn close error for %v: %v", id, err)
+					f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err)
 				}
 				if err := conn.outConn.Close(); err != nil {
-					log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
+					f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
 				}
 				delete(f.conns, id)
-				log.Debugf("forwarder: cleaned up idle UDP connection %v", id)
+				f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id)
 			}
 		}
 		f.Unlock()
@@ -107,7 +110,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)

 	if f.ctx.Err() != nil {
-		log.Debug("forwarder: context done, dropping UDP packet")
+		f.logger.Trace("forwarder: context done, dropping UDP packet")
 		return
 	}

@@ -116,7 +119,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {

 	ep, err := r.CreateEndpoint(&wq)
 	if err != nil {
-		log.Errorf("forwarder: failed to create UDP endpoint: %v", err)
+		f.logger.Error("forwarder: failed to create UDP endpoint: %v", err)
 		return
 	}

@@ -131,12 +134,16 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
 	if err != nil {
 		if err := inConn.Close(); err != nil {
-			log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
+			f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
 		}
-		log.Errorf("forwarder: UDP dial error for %v: %v", id, err)
+		f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
+
 		// TODO: Send ICMP error message
 		return
 	}
+
+	f.logger.Trace("forwarder: established UDP connection to %v", id)
+
 	connCtx, connCancel := context.WithCancel(f.ctx)
 	pConn = &udpPacketConn{
 		conn: inConn,
@@ -154,10 +161,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
 	defer func() {
 		pConn.cancel()
 		if err := pConn.conn.Close(); err != nil {
-			log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err)
+			f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err)
 		}
 		if err := pConn.outConn.Close(); err != nil {
-			log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err)
+			f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err)
 		}

 		f.udpForwarder.Lock()
@@ -180,7 +187,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
 		return
 	case err := <-errChan:
 		if err != nil && !isClosedError(err) {
-			log.Errorf("forwader: UDP proxy error for %v: %v", id, err)
+			f.logger.Error("proxyUDP: copy error: %v", err)
 		}
 		return
 	}
208	client/firewall/uspfilter/log/log.go	Normal file

@@ -0,0 +1,208 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log

import (
	"context"
	"fmt"
	"io"
	"sync"
	"sync/atomic"
	"time"

	log "github.com/sirupsen/logrus"
)

const (
	maxBatchSize         = 1024 * 16  // 16KB max batch size
	maxMessageSize       = 1024 * 2   // 2KB per message
	bufferSize           = 1024 * 256 // 256KB ring buffer
	defaultFlushInterval = 2 * time.Second
)

// Level represents log severity
type Level uint32

const (
	LevelPanic Level = iota
	LevelFatal
	LevelError
	LevelWarn
	LevelInfo
	LevelDebug
	LevelTrace
)

var levelStrings = map[Level]string{
	LevelPanic: "PANC",
	LevelFatal: "FATL",
	LevelError: "ERRO",
	LevelWarn:  "WARN",
	LevelInfo:  "INFO",
	LevelDebug: "DEBG",
	LevelTrace: "TRAC",
}

func FromLogrusLevel(level log.Level) Level {
	switch level {
	case log.TraceLevel:
		return LevelTrace
	case log.DebugLevel:
		return LevelDebug
	case log.InfoLevel:
		return LevelInfo
	case log.WarnLevel:
		return LevelWarn
	case log.ErrorLevel:
		return LevelError
	case log.FatalLevel:
		return LevelFatal
	case log.PanicLevel:
		return LevelPanic
	default:
		return LevelInfo
	}
}

// Logger is a high-performance, non-blocking logger
type Logger struct {
	output   io.Writer
	level    atomic.Uint32
	buffer   *ringBuffer
	shutdown chan struct{}
	wg       sync.WaitGroup

	// Reusable buffer pool for formatting messages
	bufPool sync.Pool
}

func NewFromLogrus(logrusLogger *log.Logger) *Logger {
	l := &Logger{
		output:   logrusLogger.Out,
		buffer:   newRingBuffer(bufferSize),
		shutdown: make(chan struct{}),
		bufPool: sync.Pool{
			New: func() interface{} {
				// Pre-allocate buffer for message formatting
				b := make([]byte, 0, maxMessageSize)
				return &b
			},
		},
	}
	l.level.Store(uint32(LevelInfo))

	l.wg.Add(1)
	go l.worker()

	return l
}

func (l *Logger) SetLevel(level Level) {
	l.level.Store(uint32(level))
}

func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
	*buf = (*buf)[:0]

	// Timestamp
	*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000000-07:00")
	*buf = append(*buf, ' ')

	// Level
	*buf = append(*buf, levelStrings[level]...)
	*buf = append(*buf, ' ')

	// Message
	if len(args) > 0 {
		*buf = append(*buf, fmt.Sprintf(format, args...)...)
	} else {
		*buf = append(*buf, format...)
	}

	*buf = append(*buf, '\n')
}

func (l *Logger) log(level Level, format string, args ...interface{}) {
	bufp := l.bufPool.Get().(*[]byte)
	l.formatMessage(bufp, level, format, args...)

	if len(*bufp) > maxMessageSize {
		*bufp = (*bufp)[:maxMessageSize]
	}
	l.buffer.Write(*bufp)

	l.bufPool.Put(bufp)
}

func (l *Logger) Trace(format string, args ...interface{}) {
	if l.level.Load() >= uint32(LevelTrace) {
		l.log(LevelTrace, format, args...)
	}
}

func (l *Logger) Debug(format string, args ...interface{}) {
	if l.level.Load() >= uint32(LevelDebug) {
		l.log(LevelDebug, format, args...)
	}
}

func (l *Logger) Info(format string, args ...interface{}) {
	if l.level.Load() >= uint32(LevelInfo) {
		l.log(LevelInfo, format, args...)
	}
}

func (l *Logger) Warn(format string, args ...interface{}) {
	if l.level.Load() >= uint32(LevelWarn) {
		l.log(LevelWarn, format, args...)
	}
}

func (l *Logger) Error(format string, args ...interface{}) {
	if l.level.Load() >= uint32(LevelError) {
		l.log(LevelError, format, args...)
	}
}

// worker periodically flushes the buffer
func (l *Logger) worker() {
	defer l.wg.Done()

	ticker := time.NewTicker(defaultFlushInterval)
	defer ticker.Stop()

	buf := make([]byte, 0, maxBatchSize)

	for {
		select {
		case <-l.shutdown:
			return
		case <-ticker.C:
			// Read accumulated messages
			n, _ := l.buffer.Read(buf[:cap(buf)])
			if n == 0 {
				continue
			}

			// Write batch
			l.output.Write(buf[:n])
		}
	}
}

// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
	close(l.shutdown)

	done := make(chan struct{})
	go func() {
		l.wg.Wait()
		close(done)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		return nil
	}
}
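Note: the logger trades durability for hot-path speed: callers format into a pooled buffer, append to a mutex-protected ring buffer, and a background worker batches writes on a two-second ticker. A small usage sketch, assuming only the API shown above:

package main

import (
	"context"
	"time"

	log "github.com/sirupsen/logrus"

	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)

func main() {
	logger := nblog.NewFromLogrus(log.StandardLogger())

	// Mirror the configured logrus level; see FromLogrusLevel above.
	logger.SetLevel(nblog.FromLogrusLevel(log.GetLevel()))

	logger.Info("tracking %d connections", 42)
	logger.Error("dial failed: %v", context.DeadlineExceeded)

	// Bounded shutdown of the flush worker; messages still sitting in
	// the ring buffer when the worker exits are dropped by design.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	_ = logger.Stop(ctx)
}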
93	client/firewall/uspfilter/log/ringbuffer.go	Normal file

@@ -0,0 +1,93 @@
package log

import "sync"

// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
	buf  []byte
	size int
	r, w int64 // Read and write positions
	mu   sync.Mutex
}

func newRingBuffer(size int) *ringBuffer {
	return &ringBuffer{
		buf:  make([]byte, size),
		size: size,
	}
}

func (r *ringBuffer) Write(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	if len(p) > r.size {
		p = p[:r.size]
	}

	n = len(p)

	// Write data, handling wrap-around
	pos := int(r.w % int64(r.size))
	writeLen := min(len(p), r.size-pos)
	copy(r.buf[pos:], p[:writeLen])

	// If we have more data and need to wrap around
	if writeLen < len(p) {
		copy(r.buf, p[writeLen:])
	}

	// Update write position
	r.w += int64(n)

	return n, nil
}

func (r *ringBuffer) Read(p []byte) (n int, err error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.w == r.r {
		return 0, nil
	}

	// Calculate available data accounting for wraparound
	available := int(r.w - r.r)
	if available < 0 {
		available += r.size
	}
	available = min(available, r.size)

	// Limit read to buffer size
	toRead := min(available, len(p))
	if toRead == 0 {
		return 0, nil
	}

	// Read data, handling wrap-around
	pos := int(r.r % int64(r.size))
	readLen := min(toRead, r.size-pos)
	n = copy(p, r.buf[pos:pos+readLen])

	// If we need more data and need to wrap around
	if readLen < toRead {
		n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
	}

	// Update read position
	r.r += int64(n)

	return n, nil
}

// min returns the smaller of two integers
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
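Note: the ring buffer favors the writer: if the reader falls behind by more than the buffer size, old bytes are overwritten rather than blocking the packet path. A quick sketch of the wrap-around behavior as an in-package test (the type is unexported):

package log

import (
	"bytes"
	"testing"
)

func TestRingBufferWrapAround(t *testing.T) {
	rb := newRingBuffer(8)

	// First write lands at positions 0-4.
	if n, _ := rb.Write([]byte("hello")); n != 5 {
		t.Fatalf("short write: %d", n)
	}
	out := make([]byte, 8)
	n, _ := rb.Read(out)
	if !bytes.Equal(out[:n], []byte("hello")) {
		t.Fatalf("got %q", out[:n])
	}

	// Second write wraps: positions 5-7, then 0-2; Read reassembles it.
	rb.Write([]byte("worlds"))
	n, _ = rb.Read(out)
	if !bytes.Equal(out[:n], []byte("worlds")) {
		t.Fatalf("got %q", out[:n])
	}
}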
@@ -17,6 +17,7 @@ import (
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

@@ -52,6 +53,7 @@ type Manager struct {
 	icmpTracker *conntrack.ICMPTracker
 	tcpTracker  *conntrack.TCPTracker
 	forwarder   *forwarder.Forwarder
+	logger      *nblog.Logger
 }

 // decoder for packages
@@ -106,15 +108,17 @@ func create(iface common.IFaceMapper) (*Manager, error) {
 		stateful: !disableConntrack,
 		// TODO: fix
 		routingEnabled: true,
+		// TODO: support changing log level from logrus
+		logger: nblog.NewFromLogrus(log.StandardLogger()),
 	}

 	// Only initialize trackers if stateful mode is enabled
 	if disableConntrack {
 		log.Info("conntrack is disabled")
 	} else {
-		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
-		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
-		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
 	}

 	intf := iface.GetWGDevice()
@@ -125,7 +129,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
 		m.routingEnabled = false
 	} else {
 		var err error
-		m.forwarder, err = forwarder.New(iface)
+		m.forwarder, err = forwarder.New(iface, m.logger)
 		if err != nil {
 			log.Errorf("failed to create forwarder: %v", err)
 			m.routingEnabled = false
@@ -455,17 +459,16 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
 	defer m.decoders.Put(d)

 	if !m.isValidPacket(d, packetData) {
-		log.Debugf("invalid packet: %v", d.decoded)
+		m.logger.Trace("Invalid packet structure")
 		return true
 	}

 	srcIP, dstIP := m.extractIPs(d)
 	if srcIP == nil {
-		log.Errorf("unknown layer: %v", d.decoded[0])
+		m.logger.Error("Unknown network layer: %v", d.decoded[0])
 		return true
 	}

 	// Check if this is local or routed traffic
 	isLocal := m.isLocalIP(dstIP)

 	// For all inbound traffic, first check if it matches a tracked connection.
@@ -476,7 +479,12 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {

 	// Handle local traffic - apply peer ACLs
 	if isLocal {
-		return m.applyRules(srcIP, packetData, rules, d)
+		drop := m.applyRules(srcIP, packetData, rules, d)
+		if drop {
+			m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
+				srcIP, dstIP)
+		}
+		return drop
 	}

 	// Handle routed traffic
@@ -484,6 +492,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
 	// We might need to apply NAT
 	// Don't handle routing if not enabled
 	if !m.routingEnabled {
+		m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
+			srcIP, dstIP)
 		return true
 	}

@@ -493,13 +503,15 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {

 	// Check route ACLs
 	if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
+		m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
+			srcIP, srcPort, dstIP, dstPort, proto)
 		return true
 	}

 	// Let forwarder handle the packet if it passed route ACLs
 	err := m.forwarder.InjectIncomingPacket(packetData)
 	if err != nil {
-		log.Errorf("Failed to inject incoming packet: %v", err)
+		m.logger.Error("Failed to inject incoming packet: %v", err)
 	}

 	// Default: drop
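Note: create() wires a single nblog.Logger into the trackers and the forwarder, so per-packet logging shares one ring buffer. The added TODO notes that the level is still fixed at the Info default; one plausible follow-up (a sketch, not part of this diff) would use the FromLogrusLevel helper introduced above:

// Hypothetical addition inside create(): keep the fast-path logger in
// sync with the logrus level configured for the rest of the client.
m.logger.SetLevel(nblog.FromLogrusLevel(log.GetLevel()))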
@@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
 		Mask: net.CIDRMask(16, 32),
 	}
 	manager.udpTracker.Close()
-	manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
+	manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
 	defer func() {
 		require.NoError(t, manager.Reset(nil))
 	}()
@@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
 	}

 	manager.udpTracker.Close() // Close the existing tracker
-	manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
+	manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
 	manager.decoders = sync.Pool{
 		New: func() any {
 			d := &decoder{