From 88d1c5a0fd57fe604c3a6eb7012e7f891bdd51c0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 16 Jul 2024 10:14:30 +0200 Subject: [PATCH 1/8] fix forwarded metrics (#2273) --- signal/server/signal.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/signal/server/signal.go b/signal/server/signal.go index fc9c19efd..02c49c31d 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -23,6 +23,8 @@ const ( labelTypeError = "error" labelTypeNotConnected = "not_connected" labelTypeNotRegistered = "not_registered" + labelTypeStream = "stream" + labelTypeMessage = "message" labelError = "error" labelErrorMissingId = "missing_id" @@ -62,6 +64,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. } if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) @@ -69,6 +72,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) } else { + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) s.metrics.MessagesForwarded.Add(context.Background(), 1) } } else { @@ -118,22 +122,21 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) } else if err != nil { return err } - start := time.Now() log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) // lookup the target peer where the message is going to if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err) //todo respond to the sender? 
- - // in milliseconds - s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6) - s.metrics.MessagesForwarded.Add(stream.Context(), 1) - } else { s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + } else { + // in milliseconds + s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(stream.Context(), 1) } } else { log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) From 12ff93ba7248c75b10b2f1c651ae06160fb717ea Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Tue, 16 Jul 2024 02:19:01 -0600 Subject: [PATCH 2/8] Ignore no unique route updates (#2266) --- client/internal/routemanager/client.go | 28 +++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 3c230df21..92c71b1e0 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -3,6 +3,7 @@ package routemanager import ( "context" "fmt" + "reflect" "time" "github.com/hashicorp/go-multierror" @@ -309,22 +310,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { }() } -func (c *clientNetwork) handleUpdate(update routesUpdate) { +func (c *clientNetwork) handleUpdate(update routesUpdate) bool { + isUpdateMapDifferent := false updateMap := make(map[route.ID]*route.Route) for _, r := range update.routes { updateMap[r.ID] = r } + if len(c.routes) != len(updateMap) { + isUpdateMapDifferent = true + } + for id, r := range c.routes { _, found := updateMap[id] if !found { close(c.routePeersNotifiers[r.Peer]) delete(c.routePeersNotifiers, r.Peer) + isUpdateMapDifferent = true + continue + } + if !reflect.DeepEqual(c.routes[id], updateMap[id]) { + isUpdateMapDifferent = true } } c.routes = updateMap + return isUpdateMapDifferent } // peersStateAndUpdateWatcher is the main point of reacting on client network routing events. 
@@ -351,13 +363,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("Received a new client network route update for [%v]", c.handler) - c.handleUpdate(update) + // hash update somehow + isTrueRouteUpdate := c.handleUpdate(update) c.updateSerial = update.updateSerial - err := c.recalculateRouteAndUpdatePeerAndSystem() - if err != nil { - log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) + if isTrueRouteUpdate { + log.Debug("Client network update contains different routes, recalculating routes") + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) + } + } else { + log.Debug("Route update is not different, skipping route recalculation") } c.startPeersStatusChangeWatcher() From 1d6f5482ddae3ae6805f279e180034a3b127c013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= <31549762+mrl5@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:19:58 +0200 Subject: [PATCH 3/8] feat(client): send logs to syslog (#2259) --- client/cmd/root.go | 2 +- util/log.go | 6 +++++- util/syslog_nonwindows.go | 20 ++++++++++++++++++++ util/syslog_windows.go | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 util/syslog_nonwindows.go create mode 100644 util/syslog_windows.go diff --git a/client/cmd/root.go b/client/cmd/root.go index f0b5d2bdf..1e5c56366 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -121,7 +121,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") - rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout") + rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. 
If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") diff --git a/util/log.go b/util/log.go index 90ccea48f..11bb7efa7 100644 --- a/util/log.go +++ b/util/log.go @@ -4,6 +4,7 @@ import ( "io" "os" "path/filepath" + "slices" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" @@ -18,8 +19,9 @@ func InitLog(logLevel string, logPath string) error { log.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } + custom_outputs := []string{"console", "syslog"}; - if logPath != "" && logPath != "console" { + if logPath != "" && !slices.Contains(custom_outputs, logPath) { lumberjackLogger := &lumberjack.Logger{ // Log file absolute path, os agnostic Filename: filepath.ToSlash(logPath), @@ -29,6 +31,8 @@ func InitLog(logLevel string, logPath string) error { Compress: true, } log.SetOutput(io.Writer(lumberjackLogger)) + } else if logPath == "syslog" { + AddSyslogHook() } if os.Getenv("NB_LOG_FORMAT") == "json" { diff --git a/util/syslog_nonwindows.go b/util/syslog_nonwindows.go new file mode 100644 index 000000000..6ffbcb8be --- /dev/null +++ b/util/syslog_nonwindows.go @@ -0,0 +1,20 @@ +//go:build !windows +// +build !windows + +package util + +import ( + "log/syslog" + + log "github.com/sirupsen/logrus" + lSyslog "github.com/sirupsen/logrus/hooks/syslog" +) + +func AddSyslogHook() { + hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "") + + if err != nil { + log.Errorf("Failed creating syslog hook: %s", err) + } + log.AddHook(hook) +} diff --git a/util/syslog_windows.go b/util/syslog_windows.go new file mode 100644 index 000000000..a38d90054 --- /dev/null +++ b/util/syslog_windows.go @@ -0,0 +1,3 @@ +package util + +func AddSyslogHook() {} From f9c59a71316aa0f244a90e2564cdcb975e8c74b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= <31549762+mrl5@users.noreply.github.com> Date: Tue, 16 Jul 2024 11:50:35 +0200 Subject: [PATCH 4/8] Refactor log util (#2276) --- util/log.go | 4 ++-- util/syslog_windows.go | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/util/log.go b/util/log.go index 11bb7efa7..74b99311e 100644 --- a/util/log.go +++ b/util/log.go @@ -19,9 +19,9 @@ func InitLog(logLevel string, logPath string) error { log.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } - custom_outputs := []string{"console", "syslog"}; + customOutputs := []string{"console", "syslog"}; - if logPath != "" && !slices.Contains(custom_outputs, logPath) { + if logPath != "" && !slices.Contains(customOutputs, logPath) { lumberjackLogger := &lumberjack.Logger{ // Log file absolute path, os agnostic Filename: filepath.ToSlash(logPath), diff --git a/util/syslog_windows.go b/util/syslog_windows.go index a38d90054..171c1a459 100644 --- a/util/syslog_windows.go +++ b/util/syslog_windows.go @@ -1,3 +1,6 @@ package util -func AddSyslogHook() {} +func AddSyslogHook() { + // The syslog package is not available for Windows. This adapter is needed + // to handle windows build. 
+} From 7c595e84934ae846d53dc0264b41614467c176f4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 16 Jul 2024 15:36:51 +0200 Subject: [PATCH 5/8] Add get_registration_delay_milliseconds metric (#2275) --- signal/README.md | 3 +++ signal/metrics/app.go | 8 ++++++++ signal/server/signal.go | 15 +++++++++++++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/signal/README.md b/signal/README.md index dd2d761ad..2da47283e 100644 --- a/signal/README.md +++ b/signal/README.md @@ -90,6 +90,9 @@ The Signal Server exposes the following metrics in Prometheus format: - **registration_delay_milliseconds**: A Histogram metric that measures the time it took to register a peer in milliseconds. +- **get_registration_delay_milliseconds**: A Histogram metric that measures the time + it took to get a peer registration in + milliseconds. - **messages_forwarded_total**: A Counter metric that counts the total number of messages forwarded between peers. - **message_forward_failures_total**: A Counter metric that counts the total diff --git a/signal/metrics/app.go b/signal/metrics/app.go index fb882a5d4..f8be88be7 100644 --- a/signal/metrics/app.go +++ b/signal/metrics/app.go @@ -15,6 +15,7 @@ type AppMetrics struct { Deregistrations metric.Int64Counter RegistrationFailures metric.Int64Counter RegistrationDelay metric.Float64Histogram + GetRegistrationDelay metric.Float64Histogram MessagesForwarded metric.Int64Counter MessageForwardFailures metric.Int64Counter @@ -54,6 +55,12 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) { return nil, err } + getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds", + metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + messagesForwarded, err := meter.Int64Counter("messages_forwarded_total") if err != nil { return nil, err @@ -80,6 +87,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) { Deregistrations: deregistrations, RegistrationFailures: registrationFailures, RegistrationDelay: registrationDelay, + GetRegistrationDelay: getRegistrationDelay, MessagesForwarded: messagesForwarded, MessageForwardFailures: messageForwardFailures, diff --git a/signal/server/signal.go b/signal/server/signal.go index 02c49c31d..4ececafff 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -30,6 +30,10 @@ const ( labelErrorMissingId = "missing_id" labelErrorMissingMeta = "missing_meta" labelErrorFailedHeader = "failed_header" + + labelRegistrionStatus = "status" + labelRegistrationFound = "found" + labelRegistrationNotFound = "not_found" ) // Server an instance of a Signal server @@ -63,7 +67,10 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. return nil, fmt.Errorf("peer %s is not registered", msg.Key) } + getRegistrationStart := time.Now() + if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrionStatus, labelRegistrationFound))) start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { @@ -76,6 +83,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. 
s.metrics.MessagesForwarded.Add(context.Background(), 1) } } else { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrionStatus, labelRegistrationNotFound))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) //todo respond to the sender? @@ -125,8 +133,11 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) + getRegistrationStart := time.Now() + // lookup the target peer where the message is going to if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrionStatus, labelRegistrationFound))) start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { @@ -139,10 +150,10 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) s.metrics.MessagesForwarded.Add(stream.Context(), 1) } } else { + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrionStatus, labelRegistrationNotFound))) + s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) //todo respond to the sender? - - s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) } } <-stream.Context().Done() From 668d229b67383d9af9e11430b07c366e046f50af Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 16 Jul 2024 16:55:57 +0200 Subject: [PATCH 6/8] Fix metric label typo (#2278) --- signal/server/signal.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/signal/server/signal.go b/signal/server/signal.go index 4ececafff..219bdcc41 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -31,7 +31,7 @@ const ( labelErrorMissingMeta = "missing_meta" labelErrorFailedHeader = "failed_header" - labelRegistrionStatus = "status" + labelRegistrationStatus = "status" labelRegistrationFound = "found" labelRegistrationNotFound = "not_found" ) @@ -70,7 +70,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. 
getRegistrationStart := time.Now() if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrionStatus, labelRegistrationFound))) + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { @@ -83,7 +83,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. s.metrics.MessagesForwarded.Add(context.Background(), 1) } } else { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrionStatus, labelRegistrationNotFound))) + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) //todo respond to the sender? @@ -137,7 +137,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) // lookup the target peer where the message is going to if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrionStatus, labelRegistrationFound))) + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { @@ -150,7 +150,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) s.metrics.MessagesForwarded.Add(stream.Context(), 1) } } else { - s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrionStatus, labelRegistrationNotFound))) + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) //todo respond to the sender? 
From a711e116a3f59edcae6f3197d20723511001da7d Mon Sep 17 00:00:00 2001 From: ctrl-zzz <78654296+ctrl-zzz@users.noreply.github.com> Date: Tue, 16 Jul 2024 17:38:12 +0200 Subject: [PATCH 7/8] fix: save peer status correctly in sqlstore (#2262) * fix: save peer status correctly in sqlstore https://github.com/netbirdio/netbird/issues/2110#issuecomment-2162768273 * feat: update test function * refactor: simplify status update --- management/server/sql_store.go | 11 ++++++++--- management/server/sql_store_test.go | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 41e9fde8b..37cc10d8b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus - result := s.db.Model(&nbpeer.Peer{}). - Where("account_id = ? AND id = ?", accountID, peerID). - Updates(peerCopy) + fieldsToUpdate := []string{ + "peer_status_last_seen", "peer_status_connected", + "peer_status_login_expired", "peer_status_required_approval", + } + result := s.db.Model(&nbpeer.Peer{}). + Select(fieldsToUpdate). + Where("account_id = ? AND id = ?", accountID, peerID). + Updates(&peerCopy) if result.Error != nil { return result.Error } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index e3ba00b56..f46ca7e5d 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -373,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -388,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) From f17016b5e5deaee4a904801e724527dcca638a9b Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 18 Jul 2024 13:50:44 +0300 Subject: [PATCH 8/8] Skip peer update on unchanged network map (#2236) * Enhance network updates by skipping unchanged messages Optimizes the network update process by skipping updates where no changes in the peer update message received. 
* Add unit tests * add locks * Improve concurrency and update peer message handling * Refactor account manager network update tests * fix test * Fix inverted network map update condition * Add default group and policy to test data * Run peer updates in a separate goroutine * Refactor * Refactor lock --- go.mod | 1 + go.sum | 2 + management/server/account_test.go | 332 ++++++++++++++---------- management/server/network.go | 4 +- management/server/peer.go | 4 +- management/server/peer/peer.go | 18 +- management/server/testdata/store.json | 38 ++- management/server/updatechannel.go | 86 +++++- management/server/updatechannel_test.go | 103 ++++++++ 9 files changed, 425 insertions(+), 163 deletions(-) diff --git a/go.mod b/go.mod index 1da44da3b..af6aa327f 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/r3labs/diff v1.1.0 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 diff --git a/go.sum b/go.sum index 842311344..f22e26be6 100644 --- a/go.sum +++ b/go.sum @@ -413,6 +413,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/r3labs/diff v1.1.0 h1:V53xhrbTHrWFWq3gI4b94AjgEJOerO1+1l0xyHOBi8M= +github.com/r3labs/diff v1.1.0/go.mod h1:7WjXasNzi0vJetRcB/RqNl5dlIsmXcTTLmF5IoH6Xig= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= diff --git a/management/server/account_test.go b/management/server/account_test.go index 71b43bd65..e6c9b60da 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1108,61 +1108,132 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return +func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + group := group.Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, } - userID := "account_creator" + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() - account, err := createAccount(manager, "test_account", userID, "") - if err != nil { - t.Fatal(err) - } - - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - t.Fatal("error creating setup key") - return - } - - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } - - getPeer := func() *nbpeer.Peer { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return nil + message := <-updMsg + 
networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } - expectedPeerKey := key.PublicKey().String() + }() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - if err != nil { - t.Fatalf("expecting peer1 to be added, got failure %v", err) - return nil - } - - return peer - } - - peer1 := getPeer() - peer2 := getPeer() - peer3 := getPeer() - - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) return } + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"group-id"}, + Destinations: []string{"group-id"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + t.Errorf("save policy: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { + t.Errorf("delete peer: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer 
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) @@ -1185,108 +1256,40 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { }, } + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + t.Errorf("save policy: %v", err) + return + } + wg := sync.WaitGroup{} - t.Run("save group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Add(1) + go func() { + defer wg.Done() - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { - t.Errorf("save group: %v", err) - return + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + // clean policy is pre requirement for delete group + if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } - t.Run("delete policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + t.Errorf("delete group: %v", err) + return + } - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - - t.Run("save policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - t.Run("delete peer update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 1 { - t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { - t.Errorf("delete peer: %v", err) - return - } - - wg.Wait() - }) - - t.Run("delete group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { - t.Errorf("delete group: %v", err) - 
return - } - - wg.Wait() - }) + wg.Wait() } func TestAccountManager_DeletePeer(t *testing.T) { @@ -2328,3 +2331,46 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true } } + +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { + t.Helper() + + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + account, err := createAccount(manager, "test_account", userID, "") + if err != nil { + t.Fatal(err) + } + + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + if err != nil { + t.Fatal("error creating setup key") + } + + getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + expectedPeerKey := key.PublicKey().String() + + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + if err != nil { + t.Fatalf("expecting peer to be added, got failure %v", err) + } + + return peer + } + + peer1 := getPeer(manager, setupKey) + peer2 := getPeer(manager, setupKey) + peer3 := getPeer(manager, setupKey) + + return manager, account, peer1, peer2, peer3 +} diff --git a/management/server/network.go b/management/server/network.go index 0e7d753a7..91d844c3e 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -40,9 +40,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 + Serial uint64 `diff:"-"` - mu sync.Mutex `json:"-" gorm:"-"` + mu sync.Mutex `json:"-" gorm:"-" diff:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index b8605fbb7..ff30fb1ff 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -261,6 +261,8 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, + Checks: []*posture.Checks{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -932,6 +934,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, peer) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + go am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) } } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 4f808a79e..a193ac6df 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -18,35 +18,35 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // A setup key this peer was registered with - SetupKey string + SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta 
PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` // The user ID that registered the peer - UserID string + UserID string `diff:"-"` // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool + LoginExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created - CreatedAt time.Time + CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `diff:"-"` // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_"` + Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } type PeerStatus struct { //nolint:revive diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json index 1fa4e3a9a..6a8fc0712 100644 --- a/management/server/testdata/store.json +++ b/management/server/testdata/store.json @@ -19,7 +19,7 @@ "Revoked": false, "UsedTimes": 0, "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": null, + "AutoGroups": ["cq9bbkjjuspi5gd38epg"], "UsageLimit": 0, "Ephemeral": false } @@ -69,9 +69,41 @@ "LastLogin": "0001-01-01T00:00:00Z" } }, - "Groups": null, + "Groups": { + "cq9bbkjjuspi5gd38epg": { + "ID": "cq9bbkjjuspi5gd38epg", + "Name": "All", + "Peers": [] + } + }, "Rules": null, - "Policies": [], + "Policies": [ + { + "ID": "cq9bbkjjuspi5gd38eq0", + "Name": "Default", + "Description": "This is a default rule that allows connections between all the resources", + "Enabled": true, + "Rules": [ + { + "ID": "cq9bbkjjuspi5gd38eq0", + "Name": "Default", + "Description": "This is a default rule that allows connections between all the resources", + "Enabled": true, + "Action": "accept", + "Destinations": [ + "cq9bbkjjuspi5gd38epg" + ], + "Sources": [ + "cq9bbkjjuspi5gd38epg" + ], + "Bidirectional": true, + "Protocol": "all", + "Ports": null + } + ], + "SourcePostureChecks": null + } + ], "Routes": null, "NameServerGroups": null, "DNSSettings": null, diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index c11225dbc..0db5b323b 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,9 +2,12 @@ package server import ( "context" + "fmt" "sync" "time" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/r3labs/diff" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -14,14 +17,18 @@ import ( const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + NetworkMap *NetworkMap + Checks []*posture.Checks } type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels 
map[string]chan *UpdateMessage + // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. + peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.Mutex + channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,9 +36,10 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.Mutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + peerUpdateMessage: make(map[string]*UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool + // skip sending sync update to the peer if there is no change in update message, + // it will not check on turn credential refresh as we do not send network map or client posture checks + if update.NetworkMap != nil { + updated := p.handlePeerMessageUpdate(ctx, peerID, update) + if !updated { + return + } + } + p.channelsMux.Lock() + defer func() { p.channelsMux.Unlock() if p.metrics != nil { @@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() + if update.NetworkMap != nil { + lastSentUpdate := p.peerUpdateMessage[peerID] + if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() >= update.Update.NetworkMap.GetSerial() { + log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", + peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) + return + } + p.peerUpdateMessage[peerID] = update + } + if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -170,3 +200,49 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. +func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { + p.channelsMux.RLock() + lastSentUpdate := p.peerUpdateMessage[peerID] + p.channelsMux.RUnlock() + + if lastSentUpdate != nil { + updated, err := isNewPeerUpdateMessage(lastSentUpdate, update) + if err != nil { + log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) + return false + } + if !updated { + log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) + return false + } + } + + return true +} + +// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. 
+func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bool, error) { + if lastSentUpdate.Update.NetworkMap.GetSerial() >= currUpdateToSend.Update.NetworkMap.GetSerial() { + return false, nil + } + + changelog, err := diff.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks) + if err != nil { + return false, fmt.Errorf("failed to diff checks: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + changelog, err = diff.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + if err != nil { + return false, fmt.Errorf("failed to diff network map: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + return false, nil +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..6d8caab26 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -77,3 +79,104 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } + +func TestHandlePeerMessageUpdate(t *testing.T) { + tests := []struct { + name string + peerID string + existingUpdate *UpdateMessage + newUpdate *UpdateMessage + expectedResult bool + }{ + { + name: "update message with turn credentials update", + peerID: "peer", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{}, + }, + }, + expectedResult: true, + }, + { + name: "update message for peer without existing update", + peerID: "peer1", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with no changes in update", + peerID: "peer2", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + expectedResult: false, + }, + { + name: "update message with changes in checks", + peerID: "peer3", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + Checks: []*posture.Checks{{ID: "check1"}}, + }, + expectedResult: true, + }, + { + name: "update message with lower serial number", + peerID: "peer4", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPeersUpdateManager(nil) + ctx := context.Background() + + if 
tt.existingUpdate != nil { + p.peerUpdateMessage[tt.peerID] = tt.existingUpdate + } + + result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) + assert.Equal(t, tt.expectedResult, result) + }) + } +}
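
The diff:"-" tags introduced in patches 7/8 rely on r3labs/diff ignoring tagged struct fields when computing a changelog, which is what lets isNewPeerUpdateMessage treat churn in volatile peer fields (LastLogin, Status, and similar) as "no change". Below is a minimal, self-contained sketch of that behavior under those assumptions; examplePeer is a hypothetical stand-in for illustration, not code from the patches.

package main

import (
	"fmt"
	"time"

	"github.com/r3labs/diff"
)

// examplePeer mirrors how Peer tags volatile fields: a field tagged
// diff:"-" is skipped by diff.Diff, so changes in it produce no entries
// in the changelog.
type examplePeer struct {
	Name      string
	LastLogin time.Time `diff:"-"`
}

func main() {
	old := examplePeer{Name: "peer-a", LastLogin: time.Now().Add(-time.Hour)}
	cur := examplePeer{Name: "peer-a", LastLogin: time.Now()}

	changelog, err := diff.Diff(old, cur)
	if err != nil {
		panic(err)
	}

	// Only LastLogin differs, and it is excluded from comparison,
	// so the changelog is empty and an update like this would be
	// skipped rather than pushed to the peer's update channel.
	fmt.Printf("changes: %d\n", len(changelog))
}

With volatile fields excluded, only genuine network map or posture check differences yield a non-empty changelog, so updateAccountPeers avoids sending redundant sync responses on every account save.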