Handle flow updates (#3455)

This commit is contained in:
Viktor Liu 2025-03-07 13:56:00 +01:00 committed by GitHub
parent 3c3a454e61
commit 54be772ffd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 154 additions and 94 deletions

View File

@ -58,13 +58,14 @@ func (l *Logger) startReceiver() {
if l.enabled.Load() { if l.enabled.Load() {
return return
} }
l.mux.Lock() l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx) ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel l.cancelReceiver = cancel
l.mux.Unlock() l.mux.Unlock()
c := make(rcvChan, 100) c := make(rcvChan, 100)
l.rcvChan.Swap(&c) l.rcvChan.Store(&c)
l.enabled.Store(true) l.enabled.Store(true)
for { for {
@ -100,6 +101,7 @@ func (l *Logger) stop() {
l.cancelReceiver() l.cancelReceiver()
l.cancelReceiver = nil l.cancelReceiver = nil
} }
l.rcvChan.Store(nil)
l.mux.Unlock() l.mux.Unlock()
} }

View File

@ -2,6 +2,7 @@ package netflow
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
@ -45,46 +46,80 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
} }
} }
// Update applies new flow configuration settings
// needsNewClient reports whether the flow client has to be (re)created.
// That is the case when there is no previous config, when the previous config
// had flow disabled, or when the URL or either token field differs from the
// config currently stored on the manager.
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
	// Nothing usable from before: a client must be created from scratch
	if previous == nil || !previous.Enabled {
		return true
	}

	active := m.flowConfig
	sameEndpoint := previous.URL == active.URL
	sameToken := previous.TokenPayload == active.TokenPayload &&
		previous.TokenSignature == active.TokenSignature

	return !(sameEndpoint && sameToken)
}
// enableFlow starts components for flow tracking.
//
// previous is the configuration that was active before the current one was
// stored in m.flowConfig; it decides whether the existing client can be reused.
// Returns an error when the client cannot be created or conntrack fails to start.
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
	// first make sender ready so events don't pile up
	if m.needsNewClient(previous) {
		if m.receiverClient != nil {
			if err := m.receiverClient.Close(); err != nil {
				log.Warnf("error closing previous flow client: %s", err)
			}
			// Forget the closed client right away: if creating the replacement
			// fails below, nothing must keep sending on it or close it again.
			m.receiverClient = nil
		}

		flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
		if err != nil {
			return fmt.Errorf("create client: %w", err)
		}
		log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
		m.receiverClient = flowClient

		// NOTE(review): every new client also starts another sender goroutine.
		// The visible part of startSender only returns on m.ctx.Done(), so a
		// reconfiguration may leave the earlier sender running - confirm.
		go m.receiveACKs(flowClient)
		go m.startSender()
	}

	m.logger.Enable()

	if m.conntrack != nil {
		if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
			return fmt.Errorf("start conntrack: %w", err)
		}
	}

	return nil
}
// disableFlow stops components for flow tracking.
//
// It stops conntrack (when present), disables the logger and closes the flow
// client. The client reference is cleared so that a repeated disable or a later
// Close does not close the same connection a second time.
// Returns the error reported by closing the client, if any.
func (m *Manager) disableFlow() error {
	if m.conntrack != nil {
		m.conntrack.Stop()
	}

	m.logger.Disable()

	if m.receiverClient == nil {
		return nil
	}

	// Detach before closing: the field must not keep pointing at a closed client
	flowClient := m.receiverClient
	m.receiverClient = nil

	return flowClient.Close()
}
// Update applies new flow configuration settings // Update applies new flow configuration settings
func (m *Manager) Update(update *nftypes.FlowConfig) error { func (m *Manager) Update(update *nftypes.FlowConfig) error {
if update == nil { if update == nil {
return nil return nil
} }
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
previous := m.flowConfig previous := m.flowConfig
m.flowConfig = update m.flowConfig = update
if update.Enabled { if update.Enabled {
if m.conntrack != nil { return m.enableFlow(previous)
if err := m.conntrack.Start(update.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
m.logger.Enable()
if previous == nil || !previous.Enabled {
flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature)
if err != nil {
return err
}
log.Infof("flow client connected to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs()
go m.startSender()
}
return nil
} }
if m.conntrack != nil { return m.disableFlow()
m.conntrack.Stop()
}
m.logger.Disable()
if previous != nil && previous.Enabled {
return m.receiverClient.Close()
}
return nil
} }
// Close cleans up all resources // Close cleans up all resources
@ -95,6 +130,13 @@ func (m *Manager) Close() {
if m.conntrack != nil { if m.conntrack != nil {
m.conntrack.Close() m.conntrack.Close()
} }
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %s", err)
}
}
m.logger.Close() m.logger.Close()
} }
@ -106,6 +148,7 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
func (m *Manager) startSender() { func (m *Manager) startSender() {
ticker := time.NewTicker(m.flowConfig.Interval) ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@ -113,35 +156,38 @@ func (m *Manager) startSender() {
case <-ticker.C: case <-ticker.C:
events := m.logger.GetEvents() events := m.logger.GetEvents()
for _, event := range events { for _, event := range events {
log.Infof("send flow event to server: %s", event.ID) if err := m.send(event); err != nil {
err := m.send(event) log.Errorf("failed to send flow event to server: %s", err)
if err != nil { continue
log.Errorf("send flow event to server: %s", err)
} }
log.Tracef("sent flow event: %s", event.ID)
} }
} }
} }
} }
func (m *Manager) receiveACKs() { func (m *Manager) receiveACKs(client *client.GRPCClient) {
if m.receiverClient == nil { err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
return log.Tracef("received flow event ack: %s", ack.EventId)
}
err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error {
log.Infof("receive flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]string{ack.EventId}) m.logger.DeleteEvents([]string{ack.EventId})
return nil return nil
}) })
if err != nil {
log.Errorf("receive flow event ack: %s", err) if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("failed to receive flow event ack: %s", err)
} }
} }
func (m *Manager) send(event *nftypes.Event) error { func (m *Manager) send(event *nftypes.Event) error {
if m.receiverClient == nil { m.mux.Lock()
client := m.receiverClient
m.mux.Unlock()
if client == nil {
return nil return nil
} }
return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event))
return client.Send(toProtoEvent(m.publicKey, event))
} }
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent { func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
@ -163,6 +209,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
TxBytes: event.TxBytes, TxBytes: event.TxBytes,
}, },
} }
if event.Protocol == nftypes.ICMP { if event.Protocol == nftypes.ICMP {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{ protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{ IcmpInfo: &proto.ICMPInfo{

View File

@ -3,6 +3,8 @@ package store
import ( import (
"sync" "sync"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
@ -26,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() { func (m *Memory) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.events = make(map[string]*types.Event) maps.Clear(m.events)
} }
func (m *Memory) GetEvents() []*types.Event { func (m *Memory) GetEvents() []*types.Event {

View File

@ -4,8 +4,10 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@ -25,95 +27,99 @@ type GRPCClient struct {
realClient proto.FlowServiceClient realClient proto.FlowServiceClient
clientConn *grpc.ClientConn clientConn *grpc.ClientConn
stream proto.FlowService_EventsClient stream proto.FlowService_EventsClient
streamMu sync.Mutex
} }
func NewClient(ctx context.Context, addr, payload, signature string) (*GRPCClient, error) { func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
var opts []grpc.DialOption
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if strings.Contains(addr, "443") { if strings.Contains(addr, "443") {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get() certPool = embeddedroots.Get()
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool, RootCAs: certPool,
})) })))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} }
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) opts = append(opts,
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(), nbgrpc.WithCustomDialer(),
grpc.WithBlock(), grpc.WithIdleTimeout(interval*2),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}), }),
withAuthToken(payload, signature), withAuthToken(payload, signature),
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
) )
conn, err := grpc.NewClient(addr, opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("dialing with context: %s", err) return nil, fmt.Errorf("creating new grpc client: %w", err)
} }
client := &GRPCClient{ return &GRPCClient{
realClient: proto.NewFlowServiceClient(conn), realClient: proto.NewFlowServiceClient(conn),
clientConn: conn, clientConn: conn,
} }, nil
return client, nil
} }
func (c *GRPCClient) Close() error { func (c *GRPCClient) Close() error {
c.streamMu.Lock()
defer c.streamMu.Unlock()
c.stream = nil
return c.clientConn.Close() return c.clientConn.Close()
} }
func (c *GRPCClient) Receive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error { func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
backOff := defaultBackoff(ctx) backOff := defaultBackoff(ctx, interval)
operation := func() error { operation := func() error {
connState := c.clientConn.GetState() return c.establishStreamAndReceive(ctx, msgHandler)
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
}
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
return err
}
c.stream = stream
err = checkHeader(stream)
if err != nil {
return err
}
return c.receive(stream, msgHandler)
} }
err := backoff.Retry(operation, backOff) if err := backoff.Retry(operation, backOff); err != nil {
if err != nil { return fmt.Errorf("receive failed permanently: %w", err)
log.Errorf("exiting the flow receiver service connection retry loop due to the unrecoverable error: %v", err)
return err
} }
return nil return nil
} }
// establishStreamAndReceive opens a new Events stream, validates its header,
// stores it as the client's current stream and then hands it to receive,
// which passes incoming messages to msgHandler.
// A connection in the Shutdown state yields an error wrapped with backoff.Permanent.
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
	if state := c.clientConn.GetState(); state == connectivity.Shutdown {
		return backoff.Permanent(errors.New("connection to flow receiver has been shut down"))
	}

	events, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
	if err != nil {
		return fmt.Errorf("create event stream: %w", err)
	}

	err = checkHeader(events)
	if err != nil {
		return fmt.Errorf("check header: %w", err)
	}

	// Store the stream under streamMu only after the header check passed
	c.streamMu.Lock()
	c.stream = events
	c.streamMu.Unlock()

	return c.receive(events, msgHandler)
}
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error { func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
for { for {
msg, err := stream.Recv() msg, err := stream.Recv()
if err != nil { if err != nil {
return err return fmt.Errorf("receive from stream: %w", err)
} }
if err := msgHandler(msg); err != nil { if err := msgHandler(msg); err != nil {
return err return fmt.Errorf("handle message: %w", err)
} }
} }
} }
@ -122,7 +128,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
header, err := stream.Header() header, err := stream.Header()
if err != nil { if err != nil {
log.Errorf("waiting for flow receiver header: %s", err) log.Errorf("waiting for flow receiver header: %s", err)
return err return fmt.Errorf("wait for header: %w", err)
} }
if len(header) == 0 { if len(header) == 0 {
@ -132,26 +138,29 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
return nil return nil
} }
func defaultBackoff(ctx context.Context) backoff.BackOff { func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1, RandomizationFactor: 1,
Multiplier: 1.7, Multiplier: 1.7,
MaxInterval: 10 * time.Second, MaxInterval: interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
} }
func (c *GRPCClient) Send(ctx context.Context, event *proto.FlowEvent) error { func (c *GRPCClient) Send(event *proto.FlowEvent) error {
if c.stream == nil { c.streamMu.Lock()
return fmt.Errorf("stream not initialized") stream := c.stream
c.streamMu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
} }
err := c.stream.Send(event) if err := stream.Send(event); err != nil {
if err != nil { return fmt.Errorf("send flow event: %w", err)
return fmt.Errorf("sending flow event: %s", err)
} }
return nil return nil