From 121dfda915cbaa17e8a16af1f96a71927fc0ece1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 20:05:26 +0100 Subject: [PATCH 01/21] [client] Fix state manager race conditions (#2890) --- client/internal/dns/server.go | 20 +++---- client/internal/engine.go | 2 +- .../routemanager/refcounter/refcounter.go | 58 +++++++++---------- .../internal/routemanager/systemops/state.go | 23 ++++---- .../systemops/systemops_generic.go | 20 +------ client/internal/statemanager/manager.go | 40 +++++++++---- util/file.go | 55 +++++++++++++----- 7 files changed, 118 insertions(+), 100 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6c4dccae7..f0277319c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - - // don't block go func() { - if err := s.stateManager.PersistState(ctx); err != nil { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } }() @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() diff --git a/client/internal/engine.go b/client/internal/engine.go index d4a3a561a..1c912220c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -297,7 +297,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 0e230ef40..f2f0a169d 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error type Counter[Key comparable, I, O any] struct { // refCountMap keeps track of the reference Ref for keys refCountMap map[Key]Ref[O] - refCountMu sync.Mutex + mu sync.Mutex // idMap keeps track of the keys associated with an ID for removal idMap map[string][]Key - idMu sync.Mutex add AddFunc[Key, I, O] remove RemoveFunc[Key, O] } @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData( // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() ref, ok := rm.refCountMap[key] return ref, ok @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { // Increment increments the reference count for the given key. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.increment(key, in) +} + +func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) { ref := rm.refCountMap[key] logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { // IncrementWithID increments the reference count for the given key and groups it under the given ID. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() - ref, err := rm.Increment(key, in) + ref, err := rm.increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], // Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.decrement(key) +} +func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) { ref, ok := rm.refCountMap[key] if !ok { logCallerF("No reference found for key %v", key) @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { // DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for _, key := range rm.idMap[id] { - if _, err := rm.Decrement(key); err != nil { + if _, err := rm.decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { // Flush removes all references and calls RemoveFunc for each key. func (rm *Counter[Key, I, O]) Flush() error { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for key := range rm.refCountMap { @@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error { // Clear removes all references without calling RemoveFunc. func (rm *Counter[Key, I, O]) Clear() { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() clear(rm.refCountMap) clear(rm.idMap) @@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 425908922..8e158711e 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa51..f8b3ebbb8 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table + // update state even if we have trouble removing it from the route table // it could be already gone r.updateState(stateManager) @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 580ccdfc7..da6dd022f 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - if m.cancel != nil { - m.cancel() + if m.cancel == nil { + return nil + } + m.cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } + bs, err := marshalWithPanicRecovery(m.states) + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - start := time.Now() go func() { - done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { @@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error { return nberrors.FormatErrorOrNil(merr) } + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.Marshal(v) + }() + + return bs, err +} diff --git a/util/file.go b/util/file.go index 4641cc1b8..f7de7ede2 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -14,6 +15,19 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enforce permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) @@ -82,29 +96,44 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { // Check context before expensive operations if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write json start: %w", ctx.Err()) } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) } + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -114,19 +143,13 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - // Check context again if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("after temp file: %w", ctx.Err()) } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil From 582bb587140789884c8905e7ae8c8c3e732077dc Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:55:33 +0100 Subject: [PATCH 02/21] Move state updates outside the refcounter (#2897) --- .../systemops/systemops_generic.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index f8b3ebbb8..3038c3ec5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,22 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager) - return nexthop, err }, - func(prefix netip.Prefix, nexthop Nexthop) error { - // update state even if we have trouble removing it from the route table - // it could be already gone - r.updateState(stateManager) - - return r.removeFromRouteTable(prefix, nexthop) - }, + r.removeFromRouteTable, ) r.refCounter = refCounter - return r.setupHooks(initAddresses) + return r.setupHooks(initAddresses, stateManager) } // updateState updates state on every change so it will be persisted regularly @@ -333,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -344,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("adding route reference: %v", err) } + r.updateState(stateManager) + return nil } afterHook := func(connID nbnet.ConnectionID) error { @@ -351,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("remove route reference: %w", err) } + r.updateState(stateManager) + return nil } From a7d5c522033beef9174898b5e43a907b282f7341 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:59:49 +0100 Subject: [PATCH 03/21] Fix error state race on mgmt connection error (#2892) --- client/internal/connect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index dff44f1d2..f76aa066b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold engineCtx, cancel := context.WithCancel(c.ctx) defer func() { - c.statusRecorder.MarkManagementDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkManagementDisconnected(err) c.statusRecorder.CleanLocalPeerState() cancel() }() From ec543f89fb819b4aae28850b370f0c06f05f7f96 Mon Sep 17 00:00:00 2001 From: Kursat Aktas Date: Sat, 16 Nov 2024 17:45:31 +0300 Subject: [PATCH 04/21] Introducing NetBird Guru on Gurubase.io (#2778) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 270c9ad87..a2d7f3897 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@
+ +
+ +

From 65a94f695f09063d505f49a6b2496f7a2e4d48b4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 12:55:02 +0100 Subject: [PATCH 05/21] use google domain for tests (#2902) --- client/internal/dns/server_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21f1f1b7d..eab9f4ecb 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { Port: 53, }, }, - Domains: []string{"customdomain.com"}, + Domains: []string{"google.com"}, Primary: false, }, }, @@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { if ips[0] != zoneRecords[0].RData { t.Fatalf("invalid zone record: %v", err) } - _, err = resolver.LookupHost(context.Background(), "customdomain.com") + _, err = resolver.LookupHost(context.Background(), "google.com") if err != nil { t.Errorf("failed to resolve: %s", err) } From 78fab877c07ee50aae95ed037218ec41f0bf2489 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 15:31:53 +0100 Subject: [PATCH 06/21] [misc] Update signing pipeline version (#2900) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 14e383a27..183cdb02c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.16" + SIGN_PIPE_VER: "v0.0.17" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" From 52ea2e84e9aa3c03fe43c5098ab50c94ff2e0818 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 00:04:50 +0100 Subject: [PATCH 07/21] [management] Add transaction metrics and exclude getAccount time from peers update (#2904) --- management/server/peer.go | 11 ++++++----- management/server/sql_store.go | 11 ++++++++++- management/server/telemetry/store_metrics.go | 12 ++++++++++++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 1405dead8..8c45e45c9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -988,6 +988,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } + start := time.Now() defer func() { if am.metrics != nil { @@ -995,11 +1001,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) - return - } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 0ebda6440..278f5443d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1123,6 +1123,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + startTime := time.Now() tx := s.db.Begin() if tx.Error != nil { return tx.Error @@ -1133,7 +1134,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor tx.Rollback() return err } - return tx.Commit().Error + + err = tx.Commit().Error + + log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) + if s.metrics != nil { + s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime)) + } + + return err } func (s *SqlStore) withTx(tx *gorm.DB) Store { diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go index b038c3d36..bb3745b5a 100644 --- a/management/server/telemetry/store_metrics.go +++ b/management/server/telemetry/store_metrics.go @@ -13,6 +13,7 @@ type StoreMetrics struct { globalLockAcquisitionDurationMs metric.Int64Histogram persistenceDurationMicro metric.Int64Histogram persistenceDurationMs metric.Int64Histogram + transactionDurationMs metric.Int64Histogram ctx context.Context } @@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er return nil, err } + transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms") + if err != nil { + return nil, err + } + return &StoreMetrics{ globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, persistenceDurationMicro: persistenceDurationMicro, persistenceDurationMs: persistenceDurationMs, + transactionDurationMs: transactionDurationMs, ctx: ctx, }, nil } @@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) { metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) } + +// CountTransactionDuration counts the duration of a store persistence operation +func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) { + metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds()) +} From eb5d0569ae0ce829a312e13ab3e9757b9cdf019f Mon Sep 17 00:00:00 2001 From: "Krzysztof Nazarewski (kdn)" Date: Tue, 19 Nov 2024 14:14:58 +0100 Subject: [PATCH 08/21] [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899) * dialer: fix crash instead of returning error * add NB_SKIP_SOCKET_MARK --- .../routemanager/systemops/systemops_linux.go | 2 +- util/grpc/dialer.go | 9 +++++++-- util/net/dialer_nonios.go | 2 +- util/net/net_linux.go | 12 ++++++++++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 0124fd95e..71a0f26ae 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" } // setIsLegacy sets the legacy routing setup diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 57ab8fd55..4fbffe342 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,9 @@ package grpc import ( "context" "crypto/tls" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net" "os/user" "runtime" @@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption { if runtime.GOOS == "linux" { currentUser, err := user.Current() if err != nil { - log.Fatalf("failed to get current user: %v", err) + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) } // the custom dialer requires root permissions which are not required for use cases run as non-root if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") dialer := &net.Dialer{} return dialer.DialContext(ctx, "tcp", addr) } } + log.Debug("Using nbnet.NewDialer()") conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) - return nil, err + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil }) diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go index 4032a75c0..34004a368 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_nonios.go @@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { - return nil, fmt.Errorf("dial: %w", err) + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } // Wrap the connection in Conn to handle Close with hooks diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 954545eb5..98f49af8d 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,9 +4,14 @@ package net import ( "fmt" + "os" "syscall" + + log "github.com/sirupsen/logrus" ) +const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" + // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { sysconn, err := conn.SyscallConn() @@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error { func SetSocketOpt(fd int) error { if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return nil + } + + // Check for the new environment variable + if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { + log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") return nil } From 5dd6a08ea6926cb1cb87a3301524b9031f4db61a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:25:49 +0100 Subject: [PATCH 09/21] link peer meta update back to account object (#2911) --- management/server/peer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/peer.go b/management/server/peer.go index 8c45e45c9..901e4815d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { From f66bbcc54c65f0856679fed1b50298e97c0ec7af Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:13:26 +0100 Subject: [PATCH 10/21] [management] Add metric for peer meta update (#2913) --- management/server/peer.go | 2 ++ .../server/telemetry/accountmanager_metrics.go | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index 901e4815d..beb833dba 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) @@ -801,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updated := peer.UpdateMetaIfNew(login.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true } diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index e4bb4e3c3..4a5a31e2d 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -13,6 +13,7 @@ type AccountManagerMetrics struct { updateAccountPeersDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram + peerMetaUpdateCount metric.Int64Counter } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, networkMapObjectCount: networkMapObjectCount, + peerMetaUpdateCount: peerMetaUpdateCount, }, nil } @@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } + +// CountPeerMetUpdate counts the number of peer meta updates +func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { + metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) +} From aa575d6f445e74f34f8353a4c413adc209c56f4b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:10:34 +0100 Subject: [PATCH 11/21] [management] Add activity events to group propagation flow (#2916) --- management/server/account.go | 38 ++++++++++- management/server/activity/codes.go | 6 ++ management/server/user.go | 97 ++++++++++++++++++++++------- 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 95c93a22b..0ab123655 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro } // UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + userPeers := make(map[string]struct{}) for pid, peer := range a.Peers { if peer.UserID == userID { @@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { continue } + oldPeers := group.Peers + groupPeers := make(map[string]struct{}) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { for pid := range groupPeers { group.Peers = append(group.Peers, pid) } + + groupUpdates[gid] = difference(group.Peers, oldPeers) } + + return groupUpdates } // UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + for _, gid := range groups { group, ok := a.Groups[gid] if !ok || group.Name == "All" { continue } + + oldPeers := group.Peers + update := make([]string, 0, len(group.Peers)) for _, pid := range group.Peers { peer, ok := a.Peers[pid] @@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } } group.Peers = update + groupUpdates[gid] = difference(oldPeers, group.Peers) } + + return groupUpdates } // BuildManager creates a new DefaultAccountManager with a provided Store @@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } + err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, fmt.Errorf("groups propagation failed: %w", err) + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1185,6 +1206,19 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { + if newSettings.GroupsPropagationEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) + // Todo: retroactively add user groups to all peers + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) + } + } + + return nil +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 603260dbc..4c57d65fb 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -148,6 +148,9 @@ const ( AccountPeerInactivityExpirationDurationUpdated Activity = 67 SetupKeyDeleted Activity = 68 + + UserGroupPropagationEnabled Activity = 69 + UserGroupPropagationDisabled Activity = 70 ) var activityMap = map[Activity]Code{ @@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, + + UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, + UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, } // StringCode returns a string code of the activity diff --git a/management/server/user.go b/management/server/user.go index 74062112a..edb5e6fd3 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, expiredPeers = append(expiredPeers, blockedPeers...) } + peerGroupsAdded := make(map[string][]string) + peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated - account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) + peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) + peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) } - events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) - eventsToStore = append(eventsToStore, events...) + userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) + eventsToStore = append(eventsToStore, userUpdateEvents...) + + userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved) + eventsToStore = append(eventsToStore, userGroupsEvents...) updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) if err != nil { @@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in }) } + return eventsToStore +} + +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() if newUser.AutoGroups != nil { removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) - } - } - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } + removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) + eventsToStore = append(eventsToStore, removedEvents...) + + addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded) + eventsToStore = append(eventsToStore, addedEvents...) + } + return eventsToStore +} + +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { + var eventsToStore []func() + for _, g := range addedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) } } + for groupID, peerIDs := range peerGroupsAdded { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta) + }) + } + } + return eventsToStore +} +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() + for _, g := range removedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) + + } else { + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + } + } + for groupID, peerIDs := range peerGroupsRemoved { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta) + }) + } + } return eventsToStore } From 1bbabf70b057c2384a077742b9e6760161e153aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 21 Nov 2024 16:53:37 +0100 Subject: [PATCH 12/21] [client] Fix allow netbird rule verdict (#2925) * Fix allow netbird rule verdict * Fix chain name --- client/firewall/nftables/manager_linux.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 3f8fac249..8e1aa0d80 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -199,7 +199,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameForward { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { chain = c break } @@ -276,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) @@ -351,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, }, UserData: []byte(allowNetbirdInputRuleID), } From 9db1932664557da94ff64bfc03f5acdcf30667ea Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:15:51 +0100 Subject: [PATCH 13/21] [management] Fix getSetupKey call (#2927) --- management/server/http/api/openapi.yml | 30 +++++++-- management/server/http/api/types.gen.go | 89 ++++++++++++++++++++++++- management/server/setupkey.go | 2 +- management/server/setupkey_test.go | 43 ++++++++---- 4 files changed, 144 insertions(+), 20 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index bfb375277..2e084f6e4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -439,17 +439,13 @@ components: example: 5 required: - accessible_peers_count - SetupKey: + SetupKeyBase: type: object properties: id: description: Setup Key ID type: string example: 2531583362 - key: - description: Setup Key value - type: string - example: A616097E-FCF0-48FA-9354-CA4A61142761 name: description: Setup key name identifier type: string @@ -518,6 +514,28 @@ components: - updated_at - usage_limit - ephemeral + SetupKeyClear: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as plain text + type: string + example: A616097E-FCF0-48FA-9354-CA4A61142761 + required: + - key + SetupKey: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as secret + type: string + example: A6160**** + required: + - key SetupKeyRequest: type: object properties: @@ -1918,7 +1936,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SetupKey' + $ref: '#/components/schemas/SetupKeyClear' '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index f219c4574..321395d25 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1062,7 +1062,94 @@ type SetupKey struct { // Id Setup Key ID Id string `json:"id"` - // Key Setup Key value + // Key Setup Key as secret + Key string `json:"key"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyBase defines model for SetupKeyBase. +type SetupKeyBase struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyClear defines model for SetupKeyClear. +type SetupKeyClear struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // Key Setup Key as plain text Key string `json:"key"` // LastUsed Setup key last usage date diff --git a/management/server/setupkey.go b/management/server/setupkey.go index cae0dfecb..ef431d3ad 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -379,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 94ed022fa..7c8200706 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -210,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_1", - Name: "group_name_1", - Peers: []string{}, - }) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_2", - Name: "group_name_2", - Peers: []string{}, - }) - if err != nil { - t.Fatal(err) + type testCase struct { + name string + keyId string + expectedFailure bool + } + + testCase1 := testCase{ + name: "Should get existing Setup Key", + keyId: plainKey.Id, + expectedFailure: false, + } + testCase2 := testCase{ + name: "Should fail to get non-existent Setup Key", + keyId: "some key", + expectedFailure: true, + } + + for _, tCase := range []testCase{testCase1, testCase2} { + t.Run(tCase.name, func(t *testing.T) { + key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId) + + if tCase.expectedFailure { + if err == nil { + t.Fatal("expected to fail") + } + return + } + + assert.NotEqual(t, plainKey.Key, key.Key) + }) } } From 2a5cb1649402d42f588d374f6a775c62e92a5522 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 22 Nov 2024 18:12:34 +0100 Subject: [PATCH 14/21] [relay] Refactor initial Relay connection (#2800) Can support firewalls with restricted WS rules allow to run engine without Relay servers keep up to date Relay address changes --- client/internal/connect.go | 3 +- client/internal/engine.go | 10 ++- client/internal/peer/status.go | 20 ++--- client/internal/peer/worker_ice.go | 14 +-- relay/client/client.go | 6 +- relay/client/client_test.go | 2 +- relay/client/guard.go | 91 +++++++++++++++---- relay/client/manager.go | 140 +++++++++++++++++++---------- relay/client/picker.go | 16 ++-- relay/client/picker_test.go | 5 +- 10 files changed, 211 insertions(+), 96 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index f76aa066b..8c2ad4aa1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -232,6 +232,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold relayURLs, token := parseRelayInfo(loginResp) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) + c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { if err := relayManager.UpdateToken(token); err != nil { @@ -242,9 +243,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) if err = relayManager.Serve(); err != nil { log.Error(err) - return wrapErr(err) } - c.statusRecorder.SetRelayMgr(relayManager) } peerConfig := loginResp.GetPeerConfig() diff --git a/client/internal/engine.go b/client/internal/engine.go index 1c912220c..dc4499e17 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -538,6 +538,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { relayMsg := wCfg.GetRelay() if relayMsg != nil { + // when we receive token we expect valid address list too c := &auth.Token{ Payload: relayMsg.GetTokenPayload(), Signature: relayMsg.GetTokenSignature(), @@ -546,9 +547,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { log.Errorf("failed to update relay token: %v", err) return fmt.Errorf("update relay token: %w", err) } + + e.relayManager.UpdateServerURLs(relayMsg.Urls) + + // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. + // We can ignore all errors because the guard will manage the reconnection retries. + _ = e.relayManager.Serve() + } else { + e.relayManager.UpdateServerURLs(nil) } - // todo update relay address in the relay manager // todo update signal } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0444dc60b..74e2ee82c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // extend the list of stun, turn servers with relay address relayStates := slices.Clone(d.relayStates) - var relayState relay.ProbeResult - // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address instanceAddr, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status - if errors.Is(err, relayClient.ErrRelayClientNotConnected) { - for _, r := range d.relayMgr.ServerURLs() { - relayStates = append(relayStates, relay.ProbeResult{ - URI: r, - }) - } - return relayStates + for _, r := range d.relayMgr.ServerURLs() { + relayStates = append(relayStates, relay.ProbeResult{ + URI: r, + Err: err, + }) } - relayState.Err = err + return relayStates } - relayState.URI = instanceAddr + relayState := relay.ProbeResult{ + URI: instanceAddr, + } return append(relayStates, relayState) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4c67cb781..7ce4797c3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -46,8 +46,6 @@ type WorkerICE struct { hasRelayOnLocally bool conn WorkerICECallbacks - selectedPriority ConnPriority - agent *ice.Agent muxAgent sync.Mutex @@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { - w.selectedPriority = connPriorityICEP2P preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { - w.selectedPriority = connPriorityICETurn preferredCandidateTypes = icemaker.CandidateTypes() } @@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { RelayedOnLocal: isRelayCandidate(pair.Local), } w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(w.selectedPriority, ci) + go w.conn.OnConnReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } + +func selectedPriority(pair *ice.CandidatePair) ConnPriority { + if isRelayed(pair) { + return connPriorityICETurn + } else { + return connPriorityICEP2P + } +} diff --git a/relay/client/client.go b/relay/client/client.go index 154c1787f..db5252f50 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -140,7 +140,7 @@ type Client struct { instanceURL *RelayAddr muInstanceURL sync.Mutex - onDisconnectListener func() + onDisconnectListener func(string) onConnectedListener func() listenerMutex sync.Mutex } @@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) { } // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. -func (c *Client) SetOnDisconnectListener(fn func()) { +func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() defer c.listenerMutex.Unlock() c.onDisconnectListener = fn @@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() { if c.onDisconnectListener == nil { return } - go c.onDisconnectListener() + go c.onDisconnectListener(c.connectionURL) } func (c *Client) notifyConnected() { diff --git a/relay/client/client_test.go b/relay/client/client_test.go index ef28203e9..7ddfba4c6 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) { } disconnected := make(chan struct{}) - relayClient.SetOnDisconnectListener(func() { + relayClient.SetOnDisconnectListener(func(_ string) { log.Infof("client disconnected") close(disconnected) }) diff --git a/relay/client/guard.go b/relay/client/guard.go index d6b6b0da5..b971363a8 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -4,65 +4,120 @@ import ( "context" "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" ) var ( - reconnectingTimeout = 5 * time.Second + reconnectingTimeout = 60 * time.Second ) // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - ctx context.Context - relayClient *Client + // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + OnNewRelayClient chan *Client + serverPicker *ServerPicker } // NewGuard creates a new guard for the relay client. -func NewGuard(context context.Context, relayClient *Client) *Guard { +func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ - ctx: context, - relayClient: relayClient, + OnNewRelayClient: make(chan *Client, 1), + serverPicker: sp, } return g } -// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection +// StartReconnectTrys is called when the relay client is disconnected from the relay server. +// It attempts to reconnect to the relay server. The function first tries a quick reconnect +// to the same server that was used before, if the server URL is still valid. If the quick +// reconnect fails, it starts a ticker to periodically attempt server picking until it +// succeeds or the context is done. +// +// Parameters: +// - ctx: The context to control the lifecycle of the reconnection attempts. +// - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent -func (g *Guard) OnDisconnected() { - if g.quickReconnect() { +func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { + if relayClient == nil { + goto RETRY + } + if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { return } - ticker := time.NewTicker(reconnectingTimeout) +RETRY: + ticker := exponentTicker(ctx) defer ticker.Stop() for { select { case <-ticker.C: - err := g.relayClient.Connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) + if err := g.retry(ctx); err != nil { + log.Errorf("failed to pick new Relay server: %s", err) continue } return - case <-g.ctx.Done(): + case <-ctx.Done(): return } } } -func (g *Guard) quickReconnect() bool { - ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) +func (g *Guard) retry(ctx context.Context) error { + log.Infof("try to pick up a new Relay server") + relayClient, err := g.serverPicker.PickServer(ctx) + if err != nil { + return err + } + + // prevent to work with a deprecated Relay client instance + g.drainRelayClientChan() + + g.OnNewRelayClient <- relayClient + return nil +} + +func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { + ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) defer cancel() <-ctx.Done() - if g.ctx.Err() != nil { + if parentCtx.Err() != nil { return false } + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := g.relayClient.Connect(); err != nil { + if err := rc.Connect(); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } return true } + +func (g *Guard) drainRelayClientChan() { + select { + case <-g.OnNewRelayClient: + default: + } +} + +func (g *Guard) isServerURLStillValid(rc *Client) bool { + for _, url := range g.serverPicker.ServerURLs.Load().([]string) { + if url == rc.connectionURL { + return true + } + } + return false +} + +func exponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 2 * time.Second, + Multiplier: 2, + MaxInterval: reconnectingTimeout, + Clock: backoff.SystemClock, + }, ctx) + + return backoff.NewTicker(bo) +} diff --git a/relay/client/manager.go b/relay/client/manager.go index b14a7701b..d847bb879 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -57,12 +57,15 @@ type ManagerService interface { // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // unused relay connection and close it. type Manager struct { - ctx context.Context - serverURLs []string - peerID string - tokenStore *relayAuth.TokenStore + ctx context.Context + peerID string + running bool + tokenStore *relayAuth.TokenStore + serverPicker *ServerPicker - relayClient *Client + relayClient *Client + // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable + relayClientMu sync.Mutex reconnectGuard *Guard relayClients map[string]*RelayTrack @@ -76,48 +79,54 @@ type Manager struct { // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { - return &Manager{ - ctx: ctx, - serverURLs: serverURLs, - peerID: peerID, - tokenStore: &relayAuth.TokenStore{}, + tokenStore := &relayAuth.TokenStore{} + + m := &Manager{ + ctx: ctx, + peerID: peerID, + tokenStore: tokenStore, + serverPicker: &ServerPicker{ + TokenStore: tokenStore, + PeerID: peerID, + }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), } + m.serverPicker.ServerURLs.Store(serverURLs) + m.reconnectGuard = NewGuard(m.serverPicker) + return m } -// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for -// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. +// Serve starts the manager, attempting to establish a connection with the relay server. +// If the connection fails, it will keep trying to reconnect in the background. +// Additionally, it starts a cleanup loop to remove unused relay connections. +// The manager will automatically reconnect to the relay server in case of disconnection. func (m *Manager) Serve() error { - if m.relayClient != nil { + if m.running { return fmt.Errorf("manager already serving") } - log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) + m.running = true + log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load()) - sp := ServerPicker{ - TokenStore: m.tokenStore, - PeerID: m.peerID, - } - - client, err := sp.PickServer(m.ctx, m.serverURLs) + client, err := m.serverPicker.PickServer(m.ctx) if err != nil { - return err + go m.reconnectGuard.StartReconnectTrys(m.ctx, nil) + } else { + m.storeClient(client) } - m.relayClient = client - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnConnectedListener(m.onServerConnected) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil + go m.listenGuardEvent(m.ctx) + go m.startCleanupLoop() + return err } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return nil, ErrRelayClientNotConnected } @@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { // Ready returns true if the home Relay client is connected to the relay server. func (m *Manager) Ready() bool { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return false } @@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) { // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + if m.relayClient == nil { + return ErrRelayClientNotConnected + } + foreign, err := m.isForeignServer(serverAddress) if err != nil { return err @@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // lost. This address will be sent to the target peer to choose the common relay server for the communication. func (m *Manager) RelayInstanceAddress() (string, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return "", ErrRelayClientNotConnected } @@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) { // ServerURLs returns the addresses of the relay servers. func (m *Manager) ServerURLs() []string { - return m.serverURLs + return m.serverPicker.ServerURLs.Load().([]string) } // HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with // Relay service. func (m *Manager) HasRelayAddress() bool { - return len(m.serverURLs) > 0 + return len(m.serverPicker.ServerURLs.Load().([]string)) > 0 +} + +func (m *Manager) UpdateServerURLs(serverURLs []string) { + log.Infof("update relay server URLs: %v", serverURLs) + m.serverPicker.ServerURLs.Store(serverURLs) } // UpdateToken updates the token in the token store. @@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return nil, err } // if connection closed then delete the relay client from the list - relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(serverAddress) - }) + relayClient.SetOnDisconnectListener(m.onServerDisconnected) rt.relayClient = relayClient rt.Unlock() @@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } +// onServerDisconnected start to reconnection for home server only func (m *Manager) onServerDisconnected(serverAddress string) { + m.relayClientMu.Lock() if serverAddress == m.relayClient.connectionURL { - go m.reconnectGuard.OnDisconnected() + go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) } + m.relayClientMu.Unlock() m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) listenGuardEvent(ctx context.Context) { + for { + select { + case rc := <-m.reconnectGuard.OnNewRelayClient: + m.storeClient(rc) + case <-ctx.Done(): + return + } + } +} + +func (m *Manager) storeClient(client *Client) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + m.relayClient = client + m.relayClient.SetOnConnectedListener(m.onServerConnected) + m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) +} + func (m *Manager) isForeignServer(address string) (bool, error) { rAddr, err := m.relayClient.ServerInstanceURL() if err != nil { @@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - if m.ctx.Err() != nil { - return - } - ticker := time.NewTicker(relayCleanupInterval) - go func() { - defer ticker.Stop() - for { - select { - case <-m.ctx.Done(): - return - case <-ticker.C: - m.cleanUpUnusedRelays() - } + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanUpUnusedRelays() } - }() + } } func (m *Manager) cleanUpUnusedRelays() { diff --git a/relay/client/picker.go b/relay/client/picker.go index 13b0547aa..eb5062dbb 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -12,10 +13,13 @@ import ( ) const ( - connectionTimeout = 30 * time.Second maxConcurrentServers = 7 ) +var ( + connectionTimeout = 30 * time.Second +) + type connResult struct { RelayClient *Client Url string @@ -24,20 +28,22 @@ type connResult struct { type ServerPicker struct { TokenStore *auth.TokenStore + ServerURLs atomic.Value PeerID string } -func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { +func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) defer cancel() - totalServers := len(urls) + totalServers := len(sp.ServerURLs.Load().([]string)) connResultChan := make(chan connResult, totalServers) successChan := make(chan connResult, 1) concurrentLimiter := make(chan struct{}, maxConcurrentServers) - for _, url := range urls { + log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string)) + for _, url := range sp.ServerURLs.Load().([]string) { // todo check if we have a successful connection so we do not need to connect to other servers concurrentLimiter <- struct{}{} go func(url string) { @@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { cr := <-resultChan if cr.Err != nil { - log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) continue } log.Infof("connected to Relay server: %s", cr.Url) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 4800e05ba..20a03e64d 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -7,16 +7,19 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { + connectionTimeout = 5 * time.Second + sp := ServerPicker{ TokenStore: nil, PeerID: "test", } + sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { - _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) + _, err := sp.PickServer(ctx) if err == nil { t.Error(err) } From 05c4aa7c2cad19f3679bf2548f8dc296dd1e043b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 18:50:47 +0100 Subject: [PATCH 15/21] [misc] Renew slack link (#2938) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a2d7f3897..e7925ae09 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@
- +
@@ -34,7 +34,7 @@
See Documentation
- Join our Slack channel + Join our Slack channel
From 56cecf849ea4f6092b0dc9b421126da399bffa95 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 20:40:30 +0100 Subject: [PATCH 16/21] Import time package (#2940) --- relay/client/picker_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 20a03e64d..28167c5ce 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { From 940d0c48c69198803b4cd88125214c5cb666bf2b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:31 +0100 Subject: [PATCH 17/21] [client] Don't return error in userspace mode without firewall (#2924) --- client/firewall/uspfilter/uspfilter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index af5dc6733..fb726395b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { // SetLegacyManagement doesn't need to be implemented for this manager func (m *Manager) SetLegacyManagement(isLegacy bool) error { if m.nativeFirewall == nil { - return errRouteNotSupported + return nil } return m.nativeFirewall.SetLegacyManagement(isLegacy) } From 0ecd5f211850d50dcf7179adbf3ae0246705f0a3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:56 +0100 Subject: [PATCH 18/21] [client] Test nftables for incompatible iptables rules (#2948) --- .../firewall/nftables/manager_linux_test.go | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 77f4f0306..33fdc4b3d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,9 +1,11 @@ package nftables import ( + "bytes" "fmt" "net" "net/netip" + "os/exec" "testing" "time" @@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) { }) } } + +func runIptablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("iptables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + require.NoError(t, err, "iptables-save failed to run") + + return stdout.String(), stderr.String() +} + +func verifyIptablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + // Check for any incompatibility warnings + require.NotContains(t, + stderr, + "incompatible", + "iptables-save produced compatibility warning. Full stderr: %s", + stderr, + ) + + // Verify standard tables are present + expectedTables := []string{ + "*filter", + "*nat", + "*mangle", + } + + for _, table := range expectedTables { + require.Contains(t, + stdout, + table, + "iptables-save output missing expected table: %s\nFull stdout: %s", + table, + stdout, + ) + } +} + +func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Reset(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + ip := net.ParseIP("100.96.0.1") + _, err = manager.AddPeerFiltering( + ip, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "test rule", + ) + require.NoError(t, err, "failed to add peer filtering rule") + + _, err = manager.AddRouteFiltering( + []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, + netip.MustParsePrefix("10.1.0.0/24"), + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + pair := fw.RouterPair{ + Source: netip.MustParsePrefix("192.168.1.0/24"), + Destination: netip.MustParsePrefix("10.0.0.0/24"), + Masquerade: true, + } + err = manager.AddNatRule(pair) + require.NoError(t, err, "failed to add NAT rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} From f1625b32bdd6c1fe0abb73756835947b99d1a97f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:12:16 +0100 Subject: [PATCH 19/21] [client] Set up sysctl and routing table name only if routing rules are available (#2933) --- .../routemanager/systemops/systemops_linux.go | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 71a0f26ae..ac4fd5c71 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } - - originalValues, err := sysctl.Setup(r.wgInterface) - if err != nil { - log.Errorf("Error setting up sysctl: %v", err) - sysctlFailed = true - } - originalSysctl = originalValues - defer func() { if err != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { @@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } } + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + originalValues, err := sysctl.Setup(r.wgInterface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + return nil, nil, nil } From 9810386937edd1665109479f745e6faaf21db840 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:19:56 +0100 Subject: [PATCH 20/21] [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) --- client/internal/routemanager/systemops/systemops_linux.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index ac4fd5c71..1d629d6e9 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) { return fmt.Errorf("remove routing rule: %w", err) } From ca12bc6953b8ba0e8645b90dc2ef8d50c8c38a01 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 25 Nov 2024 18:26:24 +0300 Subject: [PATCH 21/21] [management] Refactor posture check to use store methods (#2874) --- management/server/account.go | 2 +- management/server/dns.go | 2 +- management/server/group.go | 19 +- .../server/http/posture_checks_handler.go | 3 +- .../http/posture_checks_handler_test.go | 6 +- management/server/mock_server/account_mock.go | 6 +- management/server/nameserver.go | 10 +- management/server/peer.go | 25 +- management/server/policy.go | 6 +- management/server/posture/checks.go | 6 - management/server/posture_checks.go | 337 +++++++++++------- management/server/posture_checks_test.go | 221 +++++++----- management/server/route.go | 10 +- management/server/sql_store.go | 51 ++- management/server/sql_store_test.go | 135 +++++++ management/server/status/error.go | 5 + management/server/store.go | 4 +- management/server/testdata/extended-store.sql | 1 + 18 files changed, 589 insertions(+), 260 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0ab123655..9fb56c855 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/dns.go b/management/server/dns.go index 4551be5ab..e52be6016 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { am.updateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a36213f04..7b307cf1a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + + for _, group := range groups { + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bc..2c8204292 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d83..f400cec81 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b15..673ed33bb 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714..9119a3dec 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, err } - if anyGroupHasPeers(account, newNSGroup.Groups) { + if am.anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + if am.anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) @@ -279,9 +279,9 @@ func validateDomain(domain string) error { } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) + return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) } diff --git a/management/server/peer.go b/management/server/peer.go index beb833dba..dcb47af3b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f01..c7872591d 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { + if am.anyGroupHasPeers(account, policy.ruleGroups()) { am.updateAccountPeers(ctx, accountID) } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if !policyToSave.Enabled && !oldPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) return updateAccountPeers, nil } @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf..b2f308d76 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5..59e726c41 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,16 +2,14 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" + "github.com/rs/xid" + "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -20,219 +18,284 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) - } - - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} - -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} + +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err } - exists, uniqName := am.savePostureChecks(account, postureChecks) - - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } + + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) + var postureChecks *posture.Checks + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) if err != nil { return err } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true - } - if p.Name == postureChecks.Name { - uniqName = false - } - } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) - } - return -} - -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) - } - - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) - } - - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) - - return postureChecks, nil -} - // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) - if len(account.PostureChecks) == 0 { + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if len(postureChecks) == 0 { + return nil + } + + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } + } + + return false, nil +} + +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err + } return nil } - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) } } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) + postureChecks.ID = xid.New().String() + + return nil +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) + if err != nil { + return err } - return postureChecksList + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) + if err != nil { + return err + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) + if err != nil { + return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + } + + if slices.Contains(group.Peers, peerID) { + return true, nil } } } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck - } - } - } -} - -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d..3c5c5fc79 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +17,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) assert.NoError(t, err) @@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) @@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) assert.NoError(t, err) @@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, - }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, - { - ID: "checkB", - }, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, } + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + ID: "policyA", + AccountID: account.Id, + Rules: []*PolicyRule{ + { + ID: "ruleA", + PolicyID: "policyA", + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{postureCheckA.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d3..ecb562645 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 278f5443d..47c17bb92 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1305,12 +1305,57 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6..de939e8d0 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, postureChecks) + } else { + require.NoError(t, err) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) + } + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 8b6d0077b..44391e1f1 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,8 @@ func NewGetAccountError(err error) error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} diff --git a/management/server/store.go b/management/server/store.go index 71b0d457b..03b5821e7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -84,7 +84,9 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7..1646ff4da 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO installations VALUES(1,'');