diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index f5f864ead..6d1ed5890 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -12,7 +12,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index ad6d430b8..ed0fabe69 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -24,7 +24,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index bce850347..82a136e9c 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -15,7 +15,7 @@ import ( mgmProto "github.com/netbirdio/netbird/management/proto" ) -var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() func TestDefaultManager(t *testing.T) { networkMap := &mgmProto.NetworkMap{ diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 853bc9b9c..7c75f6bed 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -30,7 +30,7 @@ import ( "github.com/netbirdio/netbird/formatter" ) -var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() type mocWGIface struct { filter device.PacketFilter diff --git a/client/internal/engine.go b/client/internal/engine.go index d1f878820..2210384de 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -216,7 +216,6 @@ func NewEngine( statusRecorder *peer.Status, checks []*mgmProto.Checks, ) *Engine { - publicKey := config.WgPrivateKey.PublicKey() engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -235,7 +234,6 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - flowManager: netflow.NewManager(clientCtx, publicKey[:]), } if runtime.GOOS == "ios" { if !fileExists(mobileDep.StateFilePath) { @@ -304,8 +302,6 @@ func (e *Engine) Stop() error { return fmt.Errorf("failed to remove all peers: %s", err) } - e.flowManager.Close() - if e.cancel != nil { e.cancel() } @@ -315,6 +311,12 @@ func (e *Engine) Stop() error { time.Sleep(500 * time.Millisecond) e.close() + + // stop flow manager after wg interface is gone + if e.flowManager != nil { + e.flowManager.Close() + } + log.Infof("stopped Netbird Engine") ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -349,6 +351,10 @@ func (e *Engine) Start() error { } e.wgInterface = wgIface + // start flow manager right after interface creation + publicKey := e.config.WgPrivateKey.PublicKey() + e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:]) + if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") if e.config.RosenpassPermissive { diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go new file mode 100644 index 000000000..33d69ef38 --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack.go @@ -0,0 +1,268 @@ +//go:build linux && !android + +package conntrack + +import ( + "encoding/binary" + "fmt" + "net/netip" + "sync" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" + + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" +) + +const defaultChannelSize = 100 + +// ConnTrack manages kernel-based conntrack events +type ConnTrack struct { + flowLogger nftypes.FlowLogger + iface nftypes.IFaceMapper + + conn *nfct.Conn + mux sync.Mutex + + instanceID uuid.UUID + started bool + done chan struct{} +} + +// New creates a new connection tracker that interfaces with the kernel's conntrack system +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { + return &ConnTrack{ + flowLogger: flowLogger, + iface: iface, + instanceID: uuid.New(), + started: false, + done: make(chan struct{}, 1), + } +} + +// Start begins tracking connections by listening for conntrack events. This method is idempotent. +func (c *ConnTrack) Start() error { + c.mux.Lock() + defer c.mux.Unlock() + + if c.started { + return nil + } + + log.Info("Starting conntrack event listening") + + conn, err := nfct.Dial(nil) + if err != nil { + return fmt.Errorf("dial conntrack: %w", err) + } + c.conn = conn + + events := make(chan nfct.Event, defaultChannelSize) + errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{ + netfilter.GroupCTNew, + netfilter.GroupCTDestroy, + }) + + if err != nil { + if err := c.conn.Close(); err != nil { + log.Errorf("Error closing conntrack connection: %v", err) + } + c.conn = nil + return fmt.Errorf("start conntrack listener: %w", err) + } + + c.started = true + + go c.receiverRoutine(events, errChan) + + return nil +} + +func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) { + for { + select { + case event := <-events: + c.handleEvent(event) + case err := <-errChan: + log.Errorf("Error from conntrack event listener: %v", err) + if err := c.conn.Close(); err != nil { + log.Errorf("Error closing conntrack connection: %v", err) + } + return + case <-c.done: + return + } + } +} + +// Stop stops the connection tracking. This method is idempotent. +func (c *ConnTrack) Stop() { + c.mux.Lock() + defer c.mux.Unlock() + + if !c.started { + return + } + + log.Info("Stopping conntrack event listening") + + select { + case c.done <- struct{}{}: + default: + } + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Errorf("Error closing conntrack connection: %v", err) + } + c.conn = nil + } + + c.started = false +} + +// Close stops listening for events and cleans up resources +func (c *ConnTrack) Close() error { + c.mux.Lock() + defer c.mux.Unlock() + + if c.started { + select { + case c.done <- struct{}{}: + default: + } + } + + if c.conn != nil { + err := c.conn.Close() + c.conn = nil + c.started = false + if err != nil { + return fmt.Errorf("close conntrack: %w", err) + } + } + + return nil +} + +// handleEvent processes incoming conntrack events +func (c *ConnTrack) handleEvent(event nfct.Event) { + if event.Flow == nil { + return + } + + flow := *event.Flow + + proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol) + if proto == nftypes.ProtocolUnknown { + return + } + srcIP := flow.TupleOrig.IP.SourceAddress + dstIP := flow.TupleOrig.IP.DestinationAddress + + if !c.relevantFlow(srcIP, dstIP) { + return + } + + var srcPort, dstPort uint16 + var icmpType, icmpCode uint8 + + switch proto { + case nftypes.TCP, nftypes.UDP, nftypes.SCTP: + srcPort = flow.TupleOrig.Proto.SourcePort + dstPort = flow.TupleOrig.Proto.DestinationPort + case nftypes.ICMP: + icmpType = flow.TupleOrig.Proto.ICMPType + icmpCode = flow.TupleOrig.Proto.ICMPCode + } + + switch event.Type { + case nfct.EventNew: + c.handleNewFlow(flow.ID, proto, srcIP, dstIP, srcPort, dstPort, icmpType, icmpCode) + + case nfct.EventDestroy: + c.handleDestroyFlow(flow.ID, proto, srcIP, dstIP, srcPort, dstPort, icmpType, icmpCode) + } +} + +// relevantFlow checks if the flow is related to the specified interface +func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool { + // TODO: filter traffic by interface + + wgnet := c.iface.Address().Network + if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) { + return false + } + + return true +} + +func (c *ConnTrack) handleNewFlow(id uint32, proto nftypes.Protocol, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, icmpType, icmpCode uint8) { + flowID := c.getFlowID(id) + direction := c.inferDirection(srcIP, dstIP) + + log.Tracef("New %s %s connection: %s:%d -> %s:%d", direction, proto, srcIP, srcPort, dstIP, dstPort) + c.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: nftypes.TypeStart, + Direction: direction, + Protocol: proto, + SourceIP: srcIP, + DestIP: dstIP, + SourcePort: srcPort, + DestPort: dstPort, + ICMPType: icmpType, + ICMPCode: icmpCode, + }) +} + +func (c *ConnTrack) handleDestroyFlow(id uint32, proto nftypes.Protocol, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, icmpType, icmpCode uint8) { + flowID := c.getFlowID(id) + direction := c.inferDirection(srcIP, dstIP) + + log.Tracef("Ended %s %s connection: %s:%d -> %s:%d", direction, proto, srcIP, srcPort, dstIP, dstPort) + c.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: nftypes.TypeEnd, + Direction: direction, + Protocol: proto, + SourceIP: srcIP, + DestIP: dstIP, + SourcePort: srcPort, + DestPort: dstPort, + ICMPType: icmpType, + ICMPCode: icmpCode, + }) +} + +// getFlowID creates a unique UUID based on the conntrack ID and instance ID +func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], conntrackID) + return uuid.NewSHA1(c.instanceID, buf[:]) +} + +func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction { + wgaddr := c.iface.Address().IP + wgnetwork := c.iface.Address().Network + src, dst := srcIP.AsSlice(), dstIP.AsSlice() + + switch { + case wgaddr.Equal(src): + return nftypes.Egress + case wgaddr.Equal(dst): + return nftypes.Ingress + case wgnetwork.Contains(src): + // netbird network -> resource network + return nftypes.Ingress + case wgnetwork.Contains(dst): + // resource network -> netbird network + return nftypes.Egress + + // TODO: handle site2site traffic + } + + return nftypes.DirectionUnknown +} diff --git a/client/internal/netflow/conntrack/conntrack_nonlinux.go b/client/internal/netflow/conntrack/conntrack_nonlinux.go new file mode 100644 index 000000000..9044fd76c --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_nonlinux.go @@ -0,0 +1,9 @@ +//go:build !linux || android + +package conntrack + +import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker { + return nil +} diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index 527dfd256..c5bf3d7ed 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -2,35 +2,50 @@ package netflow import ( "context" + "fmt" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/netbirdio/netbird/client/internal/netflow/conntrack" "github.com/netbirdio/netbird/client/internal/netflow/logger" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/flow/client" "github.com/netbirdio/netbird/flow/proto" ) +// Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex logger types.FlowLogger flowConfig *types.FlowConfig + conntrack types.ConnTracker ctx context.Context receiverClient *client.GRPCClient publicKey []byte } -func NewManager(ctx context.Context, publicKey []byte) *Manager { +// NewManager creates a new netflow manager +func NewManager(ctx context.Context, iface types.IFaceMapper, publicKey []byte) *Manager { + flowLogger := logger.New(ctx) + + var ct types.ConnTracker + if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { + ct = conntrack.New(flowLogger, iface) + } + return &Manager{ - logger: logger.New(ctx), + logger: flowLogger, + conntrack: ct, ctx: ctx, publicKey: publicKey, } } +// Update applies new flow configuration settings func (m *Manager) Update(update *types.FlowConfig) error { if update == nil { return nil @@ -41,6 +56,12 @@ func (m *Manager) Update(update *types.FlowConfig) error { m.flowConfig = update if update.Enabled { + if m.conntrack != nil { + if err := m.conntrack.Start(); err != nil { + return fmt.Errorf("start conntrack: %w", err) + } + } + m.logger.Enable() if previous == nil || !previous.Enabled { flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature) @@ -55,6 +76,9 @@ func (m *Manager) Update(update *types.FlowConfig) error { return nil } + if m.conntrack != nil { + m.conntrack.Stop() + } m.logger.Disable() if previous != nil && previous.Enabled { return m.receiverClient.Close() @@ -63,10 +87,18 @@ func (m *Manager) Update(update *types.FlowConfig) error { return nil } +// Close cleans up all resources func (m *Manager) Close() { + m.mux.Lock() + defer m.mux.Unlock() + + if m.conntrack != nil { + m.conntrack.Close() + } m.logger.Close() } +// GetLogger returns the flow logger func (m *Manager) GetLogger() types.FlowLogger { return m.logger } diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index 02c8135fe..2729a8ea0 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -5,15 +5,18 @@ import ( "time" "github.com/google/uuid" + + "github.com/netbirdio/netbird/client/iface/device" ) type Protocol uint8 const ( - ProtocolUnknown = 0 - ICMP = 1 - TCP = 6 - UDP = 17 + ProtocolUnknown = Protocol(0) + ICMP = Protocol(1) + TCP = Protocol(6) + UDP = Protocol(17) + SCTP = Protocol(132) ) func (p Protocol) String() string { @@ -51,7 +54,7 @@ func (d Direction) String() string { } const ( - DirectionUnknown = iota + DirectionUnknown = Direction(iota) Ingress Egress ) @@ -66,7 +69,7 @@ type EventFields struct { FlowID uuid.UUID Type Type Direction Direction - Protocol uint8 + Protocol Protocol SourceIP netip.Addr DestIP netip.Addr SourcePort uint16 @@ -117,3 +120,20 @@ type Store interface { // Close closes the store Close() } + +// ConnTracker defines the interface for connection tracking functionality +type ConnTracker interface { + // Start begins tracking connections by listening for conntrack events. + Start() error + // Stop stops the connection tracking. + Stop() + // Close stops listening for events and cleans up resources + Close() error +} + +// IFaceMapper provides interface to check if we're using userspace WireGuard +type IFaceMapper interface { + IsUserspaceBind() bool + Name() string + Address() device.WGAddress +} diff --git a/go.mod b/go.mod index 4c948c007..e59b7daf4 100644 --- a/go.mod +++ b/go.mod @@ -81,6 +81,8 @@ require ( github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 + github.com/ti-mo/conntrack v0.5.1 + github.com/ti-mo/netfilter v0.5.2 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index 23a7b8aac..5fe03e708 100644 --- a/go.sum +++ b/go.sum @@ -697,6 +697,10 @@ github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3K github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw= github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0= github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ= +github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0= +github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o= +github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40= +github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=