diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go
index 1e23c1dce..6569b75b7 100644
--- a/client/internal/netflow/logger/logger.go
+++ b/client/internal/netflow/logger/logger.go
@@ -58,13 +58,14 @@ func (l *Logger) startReceiver() {
 	if l.enabled.Load() {
 		return
 	}
+
 	l.mux.Lock()
 	ctx, cancel := context.WithCancel(l.ctx)
 	l.cancelReceiver = cancel
 	l.mux.Unlock()
 
 	c := make(rcvChan, 100)
-	l.rcvChan.Swap(&c)
+	l.rcvChan.Store(&c)
 	l.enabled.Store(true)
 
 	for {
@@ -100,6 +101,7 @@ func (l *Logger) stop() {
 		l.cancelReceiver()
 		l.cancelReceiver = nil
 	}
+	l.rcvChan.Store(nil)
 	l.mux.Unlock()
 }
 
diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go
index 8ab81f4ff..ed5655f8a 100644
--- a/client/internal/netflow/manager.go
+++ b/client/internal/netflow/manager.go
@@ -2,6 +2,7 @@ package netflow
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"runtime"
 	"sync"
@@ -45,46 +46,80 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
 	}
 }
 
+// Update applies new flow configuration settings
+// needsNewClient checks if a new client needs to be created
+func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
+	current := m.flowConfig
+	return previous == nil ||
+		!previous.Enabled ||
+		previous.TokenPayload != current.TokenPayload ||
+		previous.TokenSignature != current.TokenSignature ||
+		previous.URL != current.URL
+}
+
+// enableFlow starts components for flow tracking
+func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
+	// first make sender ready so events don't pile up
+	if m.needsNewClient(previous) {
+		if m.receiverClient != nil {
+			if err := m.receiverClient.Close(); err != nil {
+				log.Warnf("error closing previous flow client: %s", err)
+			}
+		}
+
+		flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
+		if err != nil {
+			return fmt.Errorf("create client: %w", err)
+		}
+		log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
+
+		m.receiverClient = flowClient
+		go m.receiveACKs(flowClient)
+		go m.startSender()
+	}
+
+	m.logger.Enable()
+
+	if m.conntrack != nil {
+		if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
+			return fmt.Errorf("start conntrack: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// disableFlow stops components for flow tracking
+func (m *Manager) disableFlow() error {
+	if m.conntrack != nil {
+		m.conntrack.Stop()
+	}
+
+	m.logger.Disable()
+
+	if m.receiverClient != nil {
+		return m.receiverClient.Close()
+	}
+	return nil
+}
+
 // Update applies new flow configuration settings
 func (m *Manager) Update(update *nftypes.FlowConfig) error {
 	if update == nil {
 		return nil
 	}
+
 	m.mux.Lock()
 	defer m.mux.Unlock()
+
 	previous := m.flowConfig
 	m.flowConfig = update
 
 	if update.Enabled {
-		if m.conntrack != nil {
-			if err := m.conntrack.Start(update.Counters); err != nil {
-				return fmt.Errorf("start conntrack: %w", err)
-			}
-		}
-
-		m.logger.Enable()
-		if previous == nil || !previous.Enabled {
-			flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature)
-			if err != nil {
-				return err
-			}
-			log.Infof("flow client connected to %s", m.flowConfig.URL)
-			m.receiverClient = flowClient
-			go m.receiveACKs()
-			go m.startSender()
-		}
-		return nil
+		return m.enableFlow(previous)
 	}
 
-	if m.conntrack != nil {
-		m.conntrack.Stop()
-	}
-	m.logger.Disable()
-	if previous != nil && previous.Enabled {
-		return m.receiverClient.Close()
-	}
-
-	return nil
+	return m.disableFlow()
 }
 
 // Close cleans up all resources
@@ -95,6 +130,13 @@ func (m *Manager) Close() {
 	if m.conntrack != nil {
 		m.conntrack.Close()
 	}
+
+	if m.receiverClient != nil {
+		if err := m.receiverClient.Close(); err != nil {
+			log.Warnf("failed to close receiver client: %s", err)
+		}
+	}
+
 	m.logger.Close()
 }
 
@@ -106,6 +148,7 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
 func (m *Manager) startSender() {
 	ticker := time.NewTicker(m.flowConfig.Interval)
 	defer ticker.Stop()
+
 	for {
 		select {
 		case <-m.ctx.Done():
@@ -113,35 +156,38 @@ func (m *Manager) startSender() {
 		case <-ticker.C:
 			events := m.logger.GetEvents()
 			for _, event := range events {
-				log.Infof("send flow event to server: %s", event.ID)
-				err := m.send(event)
-				if err != nil {
-					log.Errorf("send flow event to server: %s", err)
+				if err := m.send(event); err != nil {
+					log.Errorf("failed to send flow event to server: %s", err)
+					continue
 				}
+				log.Tracef("sent flow event: %s", event.ID)
 			}
 		}
 	}
 }
 
-func (m *Manager) receiveACKs() {
-	if m.receiverClient == nil {
-		return
-	}
-	err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error {
-		log.Infof("receive flow event ack: %s", ack.EventId)
+func (m *Manager) receiveACKs(client *client.GRPCClient) {
+	err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
+		log.Tracef("received flow event ack: %s", ack.EventId)
 		m.logger.DeleteEvents([]string{ack.EventId})
 		return nil
 	})
-	if err != nil {
-		log.Errorf("receive flow event ack: %s", err)
+
+	if err != nil && !errors.Is(err, context.Canceled) {
+		log.Errorf("failed to receive flow event ack: %s", err)
 	}
 }
 
 func (m *Manager) send(event *nftypes.Event) error {
-	if m.receiverClient == nil {
+	m.mux.Lock()
+	client := m.receiverClient
+	m.mux.Unlock()
+
+	if client == nil {
 		return nil
 	}
-	return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event))
+
+	return client.Send(toProtoEvent(m.publicKey, event))
 }
 
 func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
@@ -163,6 +209,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
 			TxBytes: event.TxBytes,
 		},
 	}
+
 	if event.Protocol == nftypes.ICMP {
 		protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
 			IcmpInfo: &proto.ICMPInfo{
diff --git a/client/internal/netflow/store/memory.go b/client/internal/netflow/store/memory.go
index b0dcbd6f8..7fa08b510 100644
--- a/client/internal/netflow/store/memory.go
+++ b/client/internal/netflow/store/memory.go
@@ -3,6 +3,8 @@ package store
 import (
 	"sync"
 
+	"golang.org/x/exp/maps"
+
 	"github.com/netbirdio/netbird/client/internal/netflow/types"
 )
 
@@ -26,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
 func (m *Memory) Close() {
 	m.mux.Lock()
 	defer m.mux.Unlock()
-	m.events = make(map[string]*types.Event)
+	maps.Clear(m.events)
 }
 
 func (m *Memory) GetEvents() []*types.Event {
diff --git a/flow/client/client.go b/flow/client/client.go
index 47c80ef0d..b16b28c64 100644
--- a/flow/client/client.go
+++ b/flow/client/client.go
@@ -4,8 +4,10 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
+	"errors"
 	"fmt"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/cenkalti/backoff/v4"
@@ -25,95 +27,99 @@ type GRPCClient struct {
 	realClient proto.FlowServiceClient
 	clientConn *grpc.ClientConn
 	stream     proto.FlowService_EventsClient
+	streamMu   sync.Mutex
 }
 
-func NewClient(ctx context.Context, addr, payload, signature string) (*GRPCClient, error) {
+func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
+	var opts []grpc.DialOption
 
-	transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
 	if strings.Contains(addr, "443") {
-
 		certPool, err := x509.SystemCertPool()
 		if err != nil || certPool == nil {
 			log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
 			certPool = embeddedroots.Get()
 		}
 
-		transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
+		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
 			RootCAs: certPool,
-		}))
+		})))
+	} else {
+		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
 	}
 
-	connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-
-	conn, err := grpc.DialContext(
-		connCtx,
-		addr,
-		transportOption,
+	opts = append(opts,
 		nbgrpc.WithCustomDialer(),
-		grpc.WithBlock(),
+		grpc.WithIdleTimeout(interval*2),
 		grpc.WithKeepaliveParams(keepalive.ClientParameters{
 			Time:    30 * time.Second,
 			Timeout: 10 * time.Second,
 		}),
 		withAuthToken(payload, signature),
+		grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
 	)
+
+	conn, err := grpc.NewClient(addr, opts...)
 	if err != nil {
-		return nil, fmt.Errorf("dialing with context: %s", err)
+		return nil, fmt.Errorf("creating new grpc client: %w", err)
 	}
 
-	client := &GRPCClient{
+	return &GRPCClient{
 		realClient: proto.NewFlowServiceClient(conn),
 		clientConn: conn,
-	}
-
-	return client, nil
+	}, nil
 }
 
 func (c *GRPCClient) Close() error {
+	c.streamMu.Lock()
+	defer c.streamMu.Unlock()
+
+	c.stream = nil
 	return c.clientConn.Close()
 }
 
-func (c *GRPCClient) Receive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
-	backOff := defaultBackoff(ctx)
+func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
+	backOff := defaultBackoff(ctx, interval)
 
 	operation := func() error {
-		connState := c.clientConn.GetState()
-		if connState == connectivity.Shutdown {
-			return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
-		}
-
-		stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
-		if err != nil {
-			return err
-		}
-		c.stream = stream
-
-		err = checkHeader(stream)
-		if err != nil {
-			return err
-		}
-
-		return c.receive(stream, msgHandler)
+		return c.establishStreamAndReceive(ctx, msgHandler)
 	}
 
-	err := backoff.Retry(operation, backOff)
-	if err != nil {
-		log.Errorf("exiting the flow receiver service connection retry loop due to the unrecoverable error: %v", err)
-		return err
+	if err := backoff.Retry(operation, backOff); err != nil {
+		return fmt.Errorf("receive failed permanently: %w", err)
 	}
 
 	return nil
 }
 
+func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
+	if c.clientConn.GetState() == connectivity.Shutdown {
+		return backoff.Permanent(errors.New("connection to flow receiver has been shut down"))
+	}
+
+	stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
+	if err != nil {
+		return fmt.Errorf("create event stream: %w", err)
+	}
+
+	if err = checkHeader(stream); err != nil {
+		return fmt.Errorf("check header: %w", err)
+	}
+
+	c.streamMu.Lock()
+	c.stream = stream
+	c.streamMu.Unlock()
+
+	return c.receive(stream, msgHandler)
+}
+
 func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
 	for {
 		msg, err := stream.Recv()
 		if err != nil {
-			return err
+			return fmt.Errorf("receive from stream: %w", err)
 		}
 		if err := msgHandler(msg); err != nil {
-			return err
+			return fmt.Errorf("handle message: %w", err)
 		}
 	}
 }
@@ -122,7 +128,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
 	header, err := stream.Header()
 	if err != nil {
 		log.Errorf("waiting for flow receiver header: %s", err)
-		return err
+		return fmt.Errorf("wait for header: %w", err)
 	}
 
 	if len(header) == 0 {
@@ -132,26 +138,29 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
 	return nil
 }
 
-func defaultBackoff(ctx context.Context) backoff.BackOff {
+func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
 	return backoff.WithContext(&backoff.ExponentialBackOff{
 		InitialInterval:     800 * time.Millisecond,
 		RandomizationFactor: 1,
 		Multiplier:          1.7,
-		MaxInterval:         10 * time.Second,
+		MaxInterval:         interval / 2,
 		MaxElapsedTime:      3 * 30 * 24 * time.Hour, // 3 months
 		Stop:                backoff.Stop,
 		Clock:               backoff.SystemClock,
 	}, ctx)
 }
 
-func (c *GRPCClient) Send(ctx context.Context, event *proto.FlowEvent) error {
-	if c.stream == nil {
-		return fmt.Errorf("stream not initialized")
+func (c *GRPCClient) Send(event *proto.FlowEvent) error {
+	c.streamMu.Lock()
+	stream := c.stream
+	c.streamMu.Unlock()
+
+	if stream == nil {
+		return errors.New("stream not initialized")
 	}
 
-	err := c.stream.Send(event)
-	if err != nil {
-		return fmt.Errorf("sending flow event: %s", err)
+	if err := stream.Send(event); err != nil {
+		return fmt.Errorf("send flow event: %w", err)
 	}
 
 	return nil