From 67ce14eaea61f63b57151724b1079ebaae5b885d Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 8 Nov 2024 18:47:22 +0100 Subject: [PATCH 01/10] [management] Add peer lock to grpc server (#2859) * add peer lock to grpc server * remove sleep and put db update first * don't export lock method --- management/server/grpcserver.go | 35 ++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4c4ef6c3c..efe088b27 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -38,6 +39,7 @@ type GRPCServer struct { jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager + peerLocks sync.Map } // NewServer creates a new Management server @@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck @@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } + unlock() + unlock = nil + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w } func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { + unlock := s.acquirePeerLockByUID(ctx, peer.Key) + defer unlock() + + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } @@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return claims.UserId, nil } +func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) + + start := time.Now() + value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.Lock() + log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start)) + start = time.Now() + + unlock = func() { + mtx.Unlock() + log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start)) + } + + return unlock +} + // maps internal internalStatus.Error to gRPC status.Error func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok { From 08b6e9d647e3eb1bcf6d6be91ef88fda33c9f5ec Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 8 Nov 2024 23:28:02 +0100 Subject: [PATCH 02/10] [management] Fix api error message typo peers_group (#2862) --- management/server/http/routes_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index ce4edee4f..f44a164e2 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro } if req.Peer == nil && req.PeerGroups == nil { - return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided") + return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided") } if req.Peer != nil && req.PeerGroups != nil { From b4d7605147e6ddbd22c214b42ef43267bc78ce80 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 11 Nov 2024 10:53:57 +0100 Subject: [PATCH 03/10] [client] Remove loop after route calculation (#2856) - ICE do not trigger disconnect callbacks if the stated did not change - Fix route calculation callback loop - Move route state updates into protected scope by mutex - Do not calculate routes in case of peer.Open() and peer.Close() --- client/internal/peer/status.go | 88 +++++++++++-------- client/internal/peer/status_test.go | 2 +- client/internal/peer/worker_ice.go | 38 +++++--- client/internal/routemanager/client.go | 66 ++++++-------- .../routemanager/refcounter/refcounter.go | 5 ++ 5 files changed, 114 insertions(+), 85 deletions(-) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index a28992fac..241dfabbb 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.GetRoutes() != nil { - peerState.SetRoutes(receivedState.GetRoutes()) - } - skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) if receivedState.ConnStatus != peerState.ConnStatus { @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil + d.notifyPeerListChanged() + return nil +} + +func (d *Status) AddPeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") } + peerState.AddRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not + d.notifyPeerListChanged() + return nil +} + +func (d *Status) RemovePeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.DeleteRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not d.notifyPeerListChanged() return nil } @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { d.mux.Lock() defer d.mux.Unlock() + ch, found := d.changeNotify[peer] - if !found || ch == nil { - ch = make(chan struct{}) - d.changeNotify[peer] = ch + if found { + return ch } + + ch = make(chan struct{}) + d.changeNotify[peer] = ch return ch } @@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() { d.notifier.updateServerStates(d.managementState, d.signalState) } +// notifyPeerStateChangeListeners notifies route manager about the change in peer state +func (d *Status) notifyPeerStateChangeListeners(peerID string) { + ch, found := d.changeNotify[peerID] + if !found { + return + } + + close(ch) + delete(d.changeNotify, peerID) +} + func (d *Status) notifyPeerListChanged() { d.notifier.peerListChanged(d.numOfPeers()) } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 1d283433b..931ec9005 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { peerState.IP = ip - err := status.UpdatePeerState(peerState) + err := status.UpdatePeerRelayedStateToDisconnected(peerState) assert.NoError(t, err, "shouldn't return error") select { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 55894218d..4c67cb781 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -57,6 +57,9 @@ type WorkerICE struct { localUfrag string localPwd string + + // we record the last known state of the ICE agent to avoid duplicate on disconnected events + lastKnownState ice.ConnectionState } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { @@ -194,8 +197,7 @@ func (w *WorkerICE) Close() { return } - err := w.agent.Close() - if err != nil { + if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } } @@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) - if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { - w.conn.OnStatusChanged(StatusDisconnected) - - w.muxAgent.Lock() - agentCancel() - _ = agent.Close() - w.agent = nil - - w.muxAgent.Unlock() + switch state { + case ice.ConnectionStateConnected: + w.lastKnownState = ice.ConnectionStateConnected + return + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + if w.lastKnownState != ice.ConnectionStateDisconnected { + w.lastKnownState = ice.ConnectionStateDisconnected + w.conn.OnStatusChanged(StatusDisconnected) + } + w.closeAgent(agentCancel) + default: + return } }) if err != nil { @@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i return agent, nil } +func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + + cancel() + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } + w.agent = nil +} + func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index eaa232151..13e45b3a3 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] tempScore = float64(metricDiff) * 10 } - // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route - latency := time.Second + // in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route + latency := 999 * time.Millisecond if peerStatus.latency != 0 { latency = peerStatus.latency } else { - log.Warnf("peer %s has 0 latency", r.Peer) + log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) } + + // avoid negative tempScore on the higher latency calculation + if latency > 1*time.Second { + latency = 999 * time.Millisecond + } + + // higher latency is worse score tempScore += 1 - latency.Seconds() if !peerStatus.relayed { @@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] } } + log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) + switch { case chosen == "": var peers []string @@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri func (c *clientNetwork) startPeersStatusChangeWatcher() { for _, r := range c.routes { _, found := c.routePeersNotifiers[r.Peer] - if !found { - c.routePeersNotifiers[r.Peer] = make(chan struct{}) - go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer]) + if found { + continue } + + closerChan := make(chan struct{}) + c.routePeersNotifiers[r.Peer] = closerChan + go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) } } -func (c *clientNetwork) removeRouteFromWireguardPeer() error { - c.removeStateRoute() +func (c *clientNetwork) removeRouteFromWireGuardPeer() error { + if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } if err := c.handler.RemoveAllowedIPs(); err != nil { return fmt.Errorf("remove allowed IPs: %w", err) @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { var merr *multierror.Error - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) } if err := c.handler.RemoveRoute(); err != nil { @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } } @@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } - c.addStateRoute() - + err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) + if err != nil { + return fmt.Errorf("add peer state route: %w", err) + } return nil } -func (c *clientNetwork) addStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.AddRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - -func (c *clientNetwork) removeStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.DeleteRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { go func() { c.routeUpdate <- update diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index c121b7d77..0e230ef40 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` From 30f025e7dd56a4b8dc7a50dc47b0896ca00e08f3 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 11 Nov 2024 14:18:38 +0100 Subject: [PATCH 04/10] [client] fix/proxy close (#2873) When the remote peer switches the Relay instance then must to close the proxy connection to the old instance. It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to. --- client/iface/wgproxy/bind/proxy.go | 6 +++++- client/iface/wgproxy/ebpf/wrapper.go | 2 +- client/iface/wgproxy/udp/proxy.go | 2 +- client/internal/peer/conn.go | 13 +++++++++++-- relay/client/client.go | 7 +++---- relay/client/conn.go | 3 +-- relay/server/listener/ws/conn.go | 5 ++--- relay/server/peer.go | 4 ++-- 8 files changed, 26 insertions(+), 16 deletions(-) diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e0883715a..8a2e65382 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -2,6 +2,7 @@ package bind import ( "context" + "errors" "fmt" "net" "net/netip" @@ -94,7 +95,10 @@ func (p *ProxyBind) close() error { p.Bind.RemoveEndpoint(p.wgAddr) - return p.remoteConn.Close() + if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { + return rErr + } + return nil } func (p *ProxyBind) proxyToLocal(ctx context.Context) { diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index efd5fd946..54cab4e1b 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error { e.cancel() - if err := e.remoteConn.Close(); err != nil { + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { return fmt.Errorf("failed to close remote conn: %w", err) } return nil diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 200d961f3..ba0004b8a 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error { p.cancel() var result *multierror.Error - if err := p.remoteConn.Close(); err != nil { + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 84a8c221f..81c456db7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.iceP2PIsActive() { conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.statusRelay.Set(StatusConnected) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) return @@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay conn.statusRelay.Set(StatusConnected) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) @@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() { } } +func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { + if conn.wgProxyRelay != nil { + if err := conn.wgProxyRelay.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + conn.wgProxyRelay = proxy +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/relay/client/client.go b/relay/client/client.go index a82a75453..154c1787f 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "fmt" - "io" "net" "sync" "time" @@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ conn, ok := c.conns[id] c.mu.Unlock() if !ok { - return 0, io.EOF + return 0, net.ErrClosed } if conn.conn != connReference { - return 0, io.EOF + return 0, net.ErrClosed } // todo: use buffer pool instead of create new transport msg. @@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error { container, ok := c.conns[id] if !ok { - return fmt.Errorf("connection already closed") + return net.ErrClosed } if container.conn != connReference { diff --git a/relay/client/conn.go b/relay/client/conn.go index b4ff903e8..fe1b6fb52 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -1,7 +1,6 @@ package client import ( - "io" "net" "time" ) @@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) { msg, ok := <-c.messageChan if !ok { - return 0, io.EOF + return 0, net.ErrClosed } n = copy(b, msg.Payload) diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index c248963b9..12e721fdb 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "sync" "time" @@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool { func (c *Conn) ioErrHandling(err error) error { if c.isClosed() { - return io.EOF + return net.ErrClosed } var wErr *websocket.CloseError @@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error { return err } if wErr.Code == websocket.StatusNormalClosure { - return io.EOF + return net.ErrClosed } return err } diff --git a/relay/server/peer.go b/relay/server/peer.go index a9c542f84..c909c35d5 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -2,7 +2,7 @@ package server import ( "context" - "io" + "errors" "net" "sync" "time" @@ -57,7 +57,7 @@ func (p *Peer) Work() { for { n, err := p.conn.Read(buf) if err != nil { - if err != io.EOF { + if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) } return From e0bed2b0fb29c4abbc6d46191a7f54de90e29122 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 11 Nov 2024 14:55:10 +0100 Subject: [PATCH 05/10] [client] Fix race conditions (#2869) * Fix concurrent map access in status * Fix race when retrieving ctx state error * Fix race when accessing service controller server instance --- client/cmd/service.go | 10 ++++++---- client/cmd/service_controller.go | 4 ++++ client/internal/connect.go | 3 ++- client/internal/peer/status.go | 2 +- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/client/cmd/service.go b/client/cmd/service.go index 855eb30fa..3560088a7 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "sync" "github.com/kardianos/service" log "github.com/sirupsen/logrus" @@ -13,10 +14,11 @@ import ( ) type program struct { - ctx context.Context - cancel context.CancelFunc - serv *grpc.Server - serverInstance *server.Server + ctx context.Context + cancel context.CancelFunc + serv *grpc.Server + serverInstance *server.Server + serverInstanceMu sync.Mutex } func newProgram(ctx context.Context, cancel context.CancelFunc) *program { diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 86546e31c..761c86628 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error { } proto.RegisterDaemonServiceServer(p.serv, serverInstance) + p.serverInstanceMu.Lock() p.serverInstance = serverInstance + p.serverInstanceMu.Unlock() log.Printf("started daemon server: %v", split[1]) if err := p.serv.Serve(listen); err != nil { @@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error { } func (p *program) Stop(srv service.Service) error { + p.serverInstanceMu.Lock() if p.serverInstance != nil { in := new(proto.DownRequest) _, err := p.serverInstance.Down(p.ctx, in) @@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error { log.Errorf("failed to stop daemon: %v", err) } } + p.serverInstanceMu.Unlock() p.cancel() diff --git a/client/internal/connect.go b/client/internal/connect.go index bcc9d17a3..dff44f1d2 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold c.statusRecorder.MarkSignalDisconnected(nil) defer func() { - c.statusRecorder.MarkSignalDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkSignalDisconnected(err) }() // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 241dfabbb..0444dc60b 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) { func (s *State) GetRoutes() map[string]struct{} { s.Mux.RLock() defer s.Mux.RUnlock() - return s.routes + return maps.Clone(s.routes) } // LocalPeerState contains the latest state of the local peer From 6cb697eed69f522402c48672c38a0f7c5acdc0b9 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 11 Nov 2024 19:46:10 +0300 Subject: [PATCH 06/10] [management] Refactor setup key to use store methods (#2861) * Refactor setup key handling to use store methods Signed-off-by: bcmmbaga * add lock to get account groups Signed-off-by: bcmmbaga * add check for regular user Signed-off-by: bcmmbaga * get only required groups for auto-group validation Signed-off-by: bcmmbaga * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga * fix missing group removed from setup key activity Signed-off-by: bcmmbaga * Remove context from DB queries Signed-off-by: bcmmbaga * Add user permission check and add setup events into events to store slice Signed-off-by: bcmmbaga * Retrieve all groups once during setup key auto-group validation Signed-off-by: bcmmbaga * Fix lint Signed-off-by: bcmmbaga * Fix sonar Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 4 +- management/server/account_test.go | 2 +- management/server/group.go | 2 +- management/server/group/group.go | 5 + management/server/peer.go | 2 +- management/server/setupkey.go | 240 +++++++++++++++++----------- management/server/sql_store.go | 115 ++++++++----- management/server/sql_store_test.go | 4 +- management/server/status/error.go | 19 ++- management/server/store.go | 8 +- management/server/user.go | 5 + 11 files changed, 263 insertions(+), 143 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index aa7609388..583853f25 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2029,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, accountID) + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2059,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, accountID) + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 6a2d85fe8..fdf004a3b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2773,7 +2773,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") diff --git a/management/server/group.go b/management/server/group.go index bdb569e37..b2ec88cc0 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -59,7 +59,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us return nil, err } - return am.Store.GetAccountGroups(ctx, accountID) + return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers diff --git a/management/server/group/group.go b/management/server/group/group.go index d293e1afc..e98e5ecc4 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -49,3 +49,8 @@ func (g *Group) Copy() *Group { func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } + +// IsGroupAll checks if the group is a default "All" group. +func (g *Group) IsGroupAll() bool { + return g.Name == "All" +} diff --git a/management/server/peer.go b/management/server/peer.go index 9c5ab571b..8ced2a1de 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -765,7 +765,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 43b6e02c9..554c66ba4 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -4,18 +4,17 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" - "fmt" "hash/fnv" + "slices" "strconv" "strings" "time" "unicode/utf8" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -229,32 +228,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { - return nil, err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) - account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(ctx, account) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var setupKey *SetupKey + var plainKey string + var eventsToStore []func() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { + return err + } + + setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey.AccountID = accountID + + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) + eventsToStore = append(eventsToStore, events...) + + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + }) if err != nil { - return nil, status.Errorf(status.Internal, "failed adding account key") + return nil, err } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - - for _, g := range setupKey.AutoGroups { - group := account.GetGroup(g) - if group != nil { - am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) - } + for _, storeEvent := range eventsToStore { + storeEvent() } // for the creation return the plain key to the caller @@ -268,43 +278,56 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + var oldKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyToSave.Id { - oldKey = key.Copy() - break + var newKey *SetupKey + var eventsToStore []func() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { + return err } - } - if oldKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { - return nil, err - } + oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + if err != nil { + return err + } - // only auto groups, revoked status, and name can be updated for now - newKey := oldKey.Copy() - newKey.Name = keyToSave.Name - newKey.AutoGroups = keyToSave.AutoGroups - newKey.Revoked = keyToSave.Revoked - newKey.UpdatedAt = time.Now().UTC() + // only auto groups, revoked status, and name can be updated for now + newKey = oldKey.Copy() + newKey.Name = keyToSave.Name + newKey.AutoGroups = keyToSave.AutoGroups + newKey.Revoked = keyToSave.Revoked + newKey.UpdatedAt = time.Now().UTC() - account.SetupKeys[newKey.Key] = newKey + addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - if err = am.Store.SaveAccount(ctx, account); err != nil { + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) + eventsToStore = append(eventsToStore, events...) + + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + }) + if err != nil { return nil, err } @@ -312,30 +335,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) } - defer func() { - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) - } - - } - - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) - } - } - }() + for _, storeEvent := range eventsToStore { + storeEvent() + } return newKey, nil } @@ -347,16 +349,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - return setupKeys, nil + return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -366,8 +367,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) @@ -387,21 +392,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return fmt.Errorf("failed to get user: %w", err) + return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var deletedSetupKey *SetupKey + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + if err != nil { + return err + } + + return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + }) if err != nil { - return fmt.Errorf("failed to get setup key: %w", err) - } - - err = am.Store.DeleteSetupKey(ctx, accountID, keyID) - if err != nil { - return fmt.Errorf("failed to delete setup key: %w", err) + return err } am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) @@ -409,15 +422,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { - for _, group := range autoGroups { - g, ok := account.Groups[group] +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) + if err != nil { + return err + } + + for _, groupID := range autoGroupIDs { + group, ok := groups[groupID] if !ok { - return status.Errorf(status.NotFound, "group %s doesn't exist", group) + return status.Errorf(status.NotFound, "group not found: %s", groupID) } - if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") + + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } } + return nil } + +// prepareSetupKeyEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + return nil + } + + for _, g := range removedGroups { + group, ok := groups[g] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + }) + } + + for _, g := range addedGroups { + group, ok := groups[g] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) + }) + } + + return eventsToStore +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 5bc521437..9dd3e778d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -485,9 +485,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + return nil, status.NewSetupKeyNotFoundError(setupKey) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account by setup key from store") } if key.AccountID == "" { @@ -568,15 +569,15 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Find(&groups, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting groups from store") + log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account groups from the store") } return groups, nil @@ -756,9 +757,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + return "", status.NewSetupKeyNotFoundError(setupKey) } - return "", status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error) + return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store") } if accountID == "" { @@ -986,9 +988,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "setup key not found") + return nil, status.NewSetupKeyNotFoundError(key) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") } return &setupKey, nil } @@ -1006,7 +1009,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "setup key not found") + return status.NewSetupKeyNotFoundError(setupKeyID) } return nil @@ -1179,6 +1182,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// GetGroupsByIDs retrieves groups by their IDs and account ID. +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { + var groups []*nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + } + + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + + return groupsMap, nil +} + // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) @@ -1220,12 +1240,57 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db, lockStrength, accountID) + var setupKeys []*SetupKey + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&setupKeys, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup keys from store") + } + + return setupKeys, nil } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { - return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID) +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { + var setupKey *SetupKey + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewSetupKeyNotFoundError(setupKeyID) + } + log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup key from store") + } + + return setupKey, nil +} + +// SaveSetupKey saves a setup key to the database. +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save setup key to store") + } + + return nil +} + +// DeleteSetupKey deletes a setup key from the database. +func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete setup key from store") + } + + if result.RowsAffected == 0 { + return status.NewSetupKeyNotFoundError(keyID) + } + + return nil } // GetAccountNameServerGroups retrieves name server groups for an account. @@ -1238,10 +1303,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) } -func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { - return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID) -} - // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T @@ -1274,21 +1335,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } - -// deleteRecordByID deletes a record by its ID and account ID from the database. -func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { - var record T - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) - } - - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "record not found") - } - - return nil -} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b371e2313..3f3b2a453 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1274,7 +1274,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID) require.NoError(t, err) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) @@ -1290,6 +1290,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nonExistingKeyID := "non-existing-key-id" - err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } diff --git a/management/server/status/error.go b/management/server/status/error.go index a145edf80..a415d5b6e 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -103,19 +103,29 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError(err error) error { - return Errorf(NotFound, "setup key not found: %s", err) +func NewSetupKeyNotFoundError(setupKeyID string) error { + return Errorf(NotFound, "setup key: %s not found", setupKeyID) } func NewGetAccountFromStoreError(err error) error { return Errorf(Internal, "issue getting account from store: %s", err) } +// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account +func NewUserNotPartOfAccountError() error { + return Errorf(PermissionDenied, "user is not part of this account") +} + // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } +// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role. +func NewAdminPermissionError() error { + return Errorf(PermissionDenied, "admin role required to perform this action") +} + // NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context func NewStoreContextCanceledError(duration time.Duration) error { return Errorf(Internal, "store access: context canceled after %v", duration) @@ -125,8 +135,3 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } - -// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key -func NewUnauthorizedToViewSetupKeysError() error { - return Errorf(Unauthorized, "only users with admin power can view setup keys") -} diff --git a/management/server/store.go b/management/server/store.go index 087c98847..c7c066629 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -70,9 +70,10 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error @@ -96,7 +97,9 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) @@ -124,7 +127,6 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error - DeleteSetupKey(ctx context.Context, accountID, keyID string) error } type StoreEngine string diff --git a/management/server/user.go b/management/server/user.go index 9fdd3a6ee..1368b76b1 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -103,6 +103,11 @@ func (u *User) IsAdminOrServiceUser() bool { return u.HasAdminPower() || u.IsServiceUser } +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups From 664d1388aab5f283684b09ad0e47560dfab48df7 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:29:59 +0300 Subject: [PATCH 07/10] fix merge Signed-off-by: bcmmbaga --- management/server/sql_store.go | 22 ++++++++++++---------- management/server/status/error.go | 1 - 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 730fb9900..502a83f2e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -1095,10 +1096,11 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1213,7 +1215,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1256,7 +1258,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) diff --git a/management/server/status/error.go b/management/server/status/error.go index 6957a7e05..db6e4c2fb 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( From ab00c41dada6f97d13a3f53f3937071b21621c90 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:38:24 +0300 Subject: [PATCH 08/10] fix sonar Signed-off-by: bcmmbaga --- management/server/group.go | 23 +++++++++++++++++++---- management/server/group/group.go | 6 ++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index c49bb2471..1afb8f3c5 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -89,6 +89,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() var groupsToSave []*nbgroup.Group var updateAccountPeers bool @@ -213,6 +217,10 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var group *nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { @@ -260,6 +268,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error var groupIDsToDelete []string var deletedGroups []*nbgroup.Group @@ -438,6 +450,11 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return &GroupLinkError{"user", linkedUser.Id} } + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err @@ -452,10 +469,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return err } - if settings.Extra != nil { - if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil diff --git a/management/server/group/group.go b/management/server/group/group.go index bb0f5b7b6..24c60d3ce 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -55,8 +55,7 @@ func (g *Group) IsGroupAll() bool { return g.Name == "All" } -// AddPeer adds peerID to Peers if not already present, -// returning true if added. +// AddPeer adds peerID to Peers if not present, returning true if added. func (g *Group) AddPeer(peerID string) bool { if peerID == "" { return false @@ -72,8 +71,7 @@ func (g *Group) AddPeer(peerID string) bool { return true } -// RemovePeer removes peerID from Peers if present, -// returning true if removed. +// RemovePeer removes peerID from Peers if present, returning true if removed. func (g *Group) RemovePeer(peerID string) bool { for i, itemID := range g.Peers { if itemID == peerID { From 113c21b0e1b56f2c2cc9dd9dc2fd8b3a74365632 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:57:24 +0300 Subject: [PATCH 09/10] Change setup key log level to debug for missing group Signed-off-by: bcmmbaga --- management/server/setupkey.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 554c66ba4..f055d877f 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -449,14 +449,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } From d23b5c892b923ad5b3a8f45368da13de26ef64ea Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:58:22 +0300 Subject: [PATCH 10/10] Retrieve modified peers once for group events Signed-off-by: bcmmbaga --- management/server/group.go | 35 ++++++++++++++++++++-------------- management/server/sql_store.go | 17 +++++++++++++++++ management/server/store.go | 1 + 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 1afb8f3c5..57960e7f9 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -156,34 +156,41 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac }) } + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + for _, peerID := range addedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 502a83f2e..7c741d35c 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1095,6 +1095,23 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +// GetPeersByIDs retrieves peers by their IDs and account ID. +func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) diff --git a/management/server/store.go b/management/server/store.go index 2a0c44c67..71b0d457b 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -93,6 +93,7 @@ type Store interface { GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error