diff --git a/client/cmd/service.go b/client/cmd/service.go index 855eb30fa..3560088a7 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "sync" "github.com/kardianos/service" log "github.com/sirupsen/logrus" @@ -13,10 +14,11 @@ import ( ) type program struct { - ctx context.Context - cancel context.CancelFunc - serv *grpc.Server - serverInstance *server.Server + ctx context.Context + cancel context.CancelFunc + serv *grpc.Server + serverInstance *server.Server + serverInstanceMu sync.Mutex } func newProgram(ctx context.Context, cancel context.CancelFunc) *program { diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 86546e31c..761c86628 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error { } proto.RegisterDaemonServiceServer(p.serv, serverInstance) + p.serverInstanceMu.Lock() p.serverInstance = serverInstance + p.serverInstanceMu.Unlock() log.Printf("started daemon server: %v", split[1]) if err := p.serv.Serve(listen); err != nil { @@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error { } func (p *program) Stop(srv service.Service) error { + p.serverInstanceMu.Lock() if p.serverInstance != nil { in := new(proto.DownRequest) _, err := p.serverInstance.Down(p.ctx, in) @@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error { log.Errorf("failed to stop daemon: %v", err) } } + p.serverInstanceMu.Unlock() p.cancel() diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e0883715a..8a2e65382 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -2,6 +2,7 @@ package bind import ( "context" + "errors" "fmt" "net" "net/netip" @@ -94,7 +95,10 @@ func (p *ProxyBind) close() error { p.Bind.RemoveEndpoint(p.wgAddr) - return p.remoteConn.Close() + if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { + return rErr + } + return nil } func (p *ProxyBind) proxyToLocal(ctx context.Context) { diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index efd5fd946..54cab4e1b 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error { e.cancel() - if err := e.remoteConn.Close(); err != nil { + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { return fmt.Errorf("failed to close remote conn: %w", err) } return nil diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 200d961f3..ba0004b8a 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error { p.cancel() var result *multierror.Error - if err := p.remoteConn.Close(); err != nil { + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } diff --git a/client/internal/connect.go b/client/internal/connect.go index bcc9d17a3..dff44f1d2 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold c.statusRecorder.MarkSignalDisconnected(nil) defer func() { - c.statusRecorder.MarkSignalDisconnected(state.err) + _, err := state.Status() + 
c.statusRecorder.MarkSignalDisconnected(err) }() // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 84a8c221f..81c456db7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.iceP2PIsActive() { conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.statusRelay.Set(StatusConnected) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) return @@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay conn.statusRelay.Set(StatusConnected) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) @@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() { } } +func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { + if conn.wgProxyRelay != nil { + if err := conn.wgProxyRelay.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + conn.wgProxyRelay = proxy +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index a28992fac..0444dc60b 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) { func (s *State) GetRoutes() map[string]struct{} { s.Mux.RLock() defer s.Mux.RUnlock() - return s.routes + return maps.Clone(s.routes) } // LocalPeerState contains the latest state of the local peer @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.GetRoutes() != nil { - peerState.SetRoutes(receivedState.GetRoutes()) - } - skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) if receivedState.ConnStatus != peerState.ConnStatus { @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil + d.notifyPeerListChanged() + return nil +} + +func (d *Status) AddPeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") } + peerState.AddRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not + d.notifyPeerListChanged() + return nil +} + +func (d *Status) RemovePeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.DeleteRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not d.notifyPeerListChanged() return nil } @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - 
close(ch)
-		d.changeNotify[receivedState.PubKey] = nil
-	}
-
+	d.notifyPeerStateChangeListeners(receivedState.PubKey)
 	d.notifyPeerListChanged()
 	return nil
 }
 
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
 		return nil
 	}
 
-	ch, found := d.changeNotify[receivedState.PubKey]
-	if found && ch != nil {
-		close(ch)
-		d.changeNotify[receivedState.PubKey] = nil
-	}
-
+	d.notifyPeerStateChangeListeners(receivedState.PubKey)
 	d.notifyPeerListChanged()
 	return nil
 }
 
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
 		return nil
 	}
 
-	ch, found := d.changeNotify[receivedState.PubKey]
-	if found && ch != nil {
-		close(ch)
-		d.changeNotify[receivedState.PubKey] = nil
-	}
-
+	d.notifyPeerStateChangeListeners(receivedState.PubKey)
 	d.notifyPeerListChanged()
 	return nil
 }
 
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
 		return nil
 	}
 
-	ch, found := d.changeNotify[receivedState.PubKey]
-	if found && ch != nil {
-		close(ch)
-		d.changeNotify[receivedState.PubKey] = nil
-	}
-
+	d.notifyPeerStateChangeListeners(receivedState.PubKey)
 	d.notifyPeerListChanged()
 	return nil
 }
 
@@ -477,11 +481,14 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
 	d.mux.Lock()
 	defer d.mux.Unlock()
+
 	ch, found := d.changeNotify[peer]
-	if !found || ch == nil {
-		ch = make(chan struct{})
-		d.changeNotify[peer] = ch
+	if found {
+		return ch
 	}
+
+	ch = make(chan struct{})
+	d.changeNotify[peer] = ch
 	return ch
 }
 
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
 	d.notifier.updateServerStates(d.managementState, d.signalState)
 }
 
+// notifyPeerStateChangeListeners notifies the route manager about the change in peer state
+func (d *Status) notifyPeerStateChangeListeners(peerID string) {
+	ch, found := d.changeNotify[peerID]
+	if !found {
+		return
+	}
+
+	close(ch)
+	delete(d.changeNotify, peerID)
+}
+
 func (d *Status) notifyPeerListChanged() {
 	d.notifier.peerListChanged(d.numOfPeers())
 }
diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go
index 1d283433b..931ec9005 100644
--- a/client/internal/peer/status_test.go
+++ b/client/internal/peer/status_test.go
@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
 
 	peerState.IP = ip
 
-	err := status.UpdatePeerState(peerState)
+	err := status.UpdatePeerRelayedStateToDisconnected(peerState)
 	assert.NoError(t, err, "shouldn't return error")
 
 	select {
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index 55894218d..4c67cb781 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -57,6 +57,9 @@ type WorkerICE struct {
 
 	localUfrag string
 	localPwd   string
+
+	// we record the last known state of the ICE agent to avoid sending duplicate disconnected events
+	lastKnownState ice.ConnectionState
 }
 
 func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
 		return
 	}
 
-	err := w.agent.Close()
-	if err != nil {
+	if err := w.agent.Close(); err != nil {
 		w.log.Warnf("failed to close ICE agent: %s", err)
 	}
 }
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
 
 	err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
 		w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
-		if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
-			w.conn.OnStatusChanged(StatusDisconnected)
-
-			w.muxAgent.Lock()
-			agentCancel()
-			_ = agent.Close()
-			w.agent = nil
-
-			w.muxAgent.Unlock()
+		switch state {
+		case ice.ConnectionStateConnected:
+			w.lastKnownState = ice.ConnectionStateConnected
+			return
+		case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
+			if w.lastKnownState != ice.ConnectionStateDisconnected {
+				w.lastKnownState = ice.ConnectionStateDisconnected
+				w.conn.OnStatusChanged(StatusDisconnected)
+			}
+			w.closeAgent(agentCancel)
+		default:
+			return
 		}
 	})
 	if err != nil {
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
 	return agent, nil
 }
 
+func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
+	w.muxAgent.Lock()
+	defer w.muxAgent.Unlock()
+
+	cancel()
+	if err := w.agent.Close(); err != nil {
+		w.log.Warnf("failed to close ICE agent: %s", err)
+	}
+	w.agent = nil
+}
+
 func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
 	// wait local endpoint configuration
 	time.Sleep(time.Second)
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index eaa232151..13e45b3a3 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
 			tempScore = float64(metricDiff) * 10
 		}
 
-		// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
-		latency := time.Second
+		// in some temporary cases, latency can be 0, so we set it to 999ms so the route is not blocked but is still avoided when possible
+		latency := 999 * time.Millisecond
 		if peerStatus.latency != 0 {
 			latency = peerStatus.latency
 		} else {
-			log.Warnf("peer %s has 0 latency", r.Peer)
+			log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
 		}
+
+		// cap the latency so that high-latency routes do not produce a negative tempScore
+		if latency > 1*time.Second {
+			latency = 999 * time.Millisecond
+		}
+
+		// higher latency results in a lower score
 		tempScore += 1 - latency.Seconds()
 
 		if !peerStatus.relayed {
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
 		}
 	}
 
+	log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
+
 	switch {
 	case chosen == "":
 		var peers []string
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
 func (c *clientNetwork) startPeersStatusChangeWatcher() {
 	for _, r := range c.routes {
 		_, found := c.routePeersNotifiers[r.Peer]
-		if !found {
-			c.routePeersNotifiers[r.Peer] = make(chan struct{})
-			go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
+		if found {
+			continue
 		}
+
+		closerChan := make(chan struct{})
+		c.routePeersNotifiers[r.Peer] = closerChan
+		go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
 	}
 }
 
-func (c *clientNetwork) removeRouteFromWireguardPeer() error {
-	c.removeStateRoute()
+func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
+	if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
+		log.Warnf("Failed to update peer state: %v", err)
+	}
 
 	if err := c.handler.RemoveAllowedIPs(); err != nil {
 		return fmt.Errorf("remove 
allowed IPs: %w", err) @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { var merr *multierror.Error - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) } if err := c.handler.RemoveRoute(); err != nil { @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } } @@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } - c.addStateRoute() - + err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) + if err != nil { + return fmt.Errorf("add peer state route: %w", err) + } return nil } -func (c *clientNetwork) addStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.AddRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - -func (c *clientNetwork) removeStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.DeleteRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { go func() { c.routeUpdate <- update diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index c121b7d77..0e230ef40 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. 
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` diff --git a/management/server/group.go b/management/server/group.go index ee42b0064..5d3014169 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -89,6 +89,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() var groupsToSave []*nbgroup.Group var updateAccountPeers bool @@ -152,34 +156,41 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac }) } + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + for _, peerID := range addedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -213,6 +224,10 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var group *nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { @@ -260,6 +275,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error var groupIDsToDelete []string var deletedGroups []*nbgroup.Group @@ -438,6 +457,11 @@ func 
validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return &GroupLinkError{"user", linkedUser.Id} } + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err @@ -452,10 +476,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return err } - if settings.Extra != nil { - if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil diff --git a/management/server/group/group.go b/management/server/group/group.go index bb0f5b7b6..24c60d3ce 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -55,8 +55,7 @@ func (g *Group) IsGroupAll() bool { return g.Name == "All" } -// AddPeer adds peerID to Peers if not already present, -// returning true if added. +// AddPeer adds peerID to Peers if not present, returning true if added. func (g *Group) AddPeer(peerID string) bool { if peerID == "" { return false @@ -72,8 +71,7 @@ func (g *Group) AddPeer(peerID string) bool { return true } -// RemovePeer removes peerID from Peers if present, -// returning true if removed. +// RemovePeer removes peerID from Peers if present, returning true if removed. func (g *Group) RemovePeer(peerID string) bool { for i, itemID := range g.Peers { if itemID == peerID { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4c4ef6c3c..efe088b27 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -38,6 +39,7 @@ type GRPCServer struct { jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager + peerLocks sync.Map } // NewServer creates a new Management server @@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck @@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } + unlock() + unlock = nil + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w } func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { + unlock := s.acquirePeerLockByUID(ctx, peer.Key) + defer unlock() + + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, 
peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } @@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return claims.UserId, nil } +func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) + + start := time.Now() + value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.Lock() + log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start)) + start = time.Now() + + unlock = func() { + mtx.Unlock() + log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start)) + } + + return unlock +} + // maps internal internalStatus.Error to gRPC status.Error func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok { diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index ce4edee4f..f44a164e2 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro } if req.Peer == nil && req.PeerGroups == nil { - return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided") + return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided") } if req.Peer != nil && req.PeerGroups != nil { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index d6e92fe3a..f055d877f 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" b64 "encoding/base64" "hash/fnv" + "slices" "strconv" "strings" "time" @@ -236,6 +237,10 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + var setupKey *SetupKey var plainKey string var eventsToStore []func() @@ -289,6 +294,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + var oldKey *SetupKey var newKey *SetupKey var eventsToStore []func() @@ -414,10 +423,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, } func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) + if err != nil { + return err + } + for _, groupID := range autoGroupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err + group, ok := groups[groupID] + if !ok { + return status.Errorf(status.NotFound, "group not found: %s", groupID) } if group.IsGroupAll() { @@ -432,26 +446,37 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { var eventsToStore []func() + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, 
LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) + return nil + } + for _, g := range removedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + group, ok := groups[g] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } - meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} - am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + }) } for _, g := range addedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + group, ok := groups[g] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } - meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} - am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) + }) } return eventsToStore diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 466d36aff..81dc704c2 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" 
+ peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -485,9 +486,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + return nil, status.NewSetupKeyNotFoundError(setupKey) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account by setup key from store") } if key.AccountID == "" { @@ -570,7 +572,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -756,9 +758,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + return "", status.NewSetupKeyNotFoundError(setupKey) } - return "", status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error) + return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store") } if accountID == "" { @@ -985,9 +988,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "setup key not found") + return nil, status.NewSetupKeyNotFoundError(key) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") } return &setupKey, nil } @@ -1005,7 +1009,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "setup key not found") + return status.NewSetupKeyNotFoundError(setupKeyID) } return nil @@ -1091,11 +1095,29 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +// GetPeersByIDs retrieves peers by their IDs and account ID. 
+func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
+	var peers []*nbpeer.Peer
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
+	if result.Error != nil {
+		log.WithContext(ctx).Errorf("failed to get peers by IDs from the store: %s", result.Error)
+		return nil, status.Errorf(status.Internal, "failed to get peers by IDs from the store")
+	}
+
+	peersMap := make(map[string]*nbpeer.Peer)
+	for _, peer := range peers {
+		peersMap[peer.ID] = peer
+	}
+
+	return peersMap, nil
+}
+
 func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
-	result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
 		Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
 	if result.Error != nil {
-		return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
+		log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
+		return status.Errorf(status.Internal, "failed to increment network serial count in store")
 	}
 	return nil
 }
@@ -1207,6 +1229,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
 	return &group, nil
 }
 
+// GetGroupsByIDs retrieves groups by their IDs and account ID.
+func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
+	var groups []*nbgroup.Group
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
+	if result.Error != nil {
+		log.WithContext(ctx).Errorf("failed to get groups by IDs from the store: %s", result.Error)
+		return nil, status.Errorf(status.Internal, "failed to get groups by IDs from the store")
+	}
+
+	groupsMap := make(map[string]*nbgroup.Group)
+	for _, group := range groups {
+		groupsMap[group.ID] = group
+	}
+
+	return groupsMap, nil
+}
+
 // SaveGroup saves a group to the store.
 func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
 	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
@@ -1236,7 +1275,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
 // DeleteGroups deletes groups from the database.
 func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
 	result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
-		Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
+		Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
 	if result.Error != nil {
 		log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
 		return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
@@ -1326,7 +1365,7 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
 // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { var setupKeys []*SetupKey - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) @@ -1339,11 +1378,11 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking // GetSetupKeyByID retrieves a setup key by its ID and account ID. func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { var setupKey *SetupKey - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "setup key not found") + return nil, status.NewSetupKeyNotFoundError(setupKeyID) } log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get setup key from store") @@ -1354,8 +1393,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre // SaveSetupKey saves a setup key to the database. func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { - result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save setup key to store") @@ -1366,15 +1404,14 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
- Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "setup key not found") + return status.NewSetupKeyNotFoundError(keyID) } return nil diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20409798b..114da1ee6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: 
accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, 
peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index bdf5c7549..ba9e01c4f 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( @@ -103,8 +102,8 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError(err error) error { - return Errorf(NotFound, "setup key not found: %s", err) +func NewSetupKeyNotFoundError(setupKeyID string) error { + return Errorf(NotFound, "setup key: %s not found", setupKeyID) } func NewGetAccountFromStoreError(err error) error { @@ -126,11 +125,6 @@ func NewAdminPermissionError() error { return Errorf(PermissionDenied, "admin role required to perform this action") } -// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context -func NewStoreContextCanceledError(duration time.Duration) error { - return Errorf(Internal, "store access: context canceled after %v", duration) -} - // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") diff --git a/management/server/store.go b/management/server/store.go index 7e2581045..03b5821e7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -71,8 +71,9 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, 
accountID, groupID string) error @@ -94,6 +95,7 @@ type Store interface { GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error diff --git a/relay/client/client.go b/relay/client/client.go index a82a75453..154c1787f 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "fmt" - "io" "net" "sync" "time" @@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ conn, ok := c.conns[id] c.mu.Unlock() if !ok { - return 0, io.EOF + return 0, net.ErrClosed } if conn.conn != connReference { - return 0, io.EOF + return 0, net.ErrClosed } // todo: use buffer pool instead of create new transport msg. @@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error { container, ok := c.conns[id] if !ok { - return fmt.Errorf("connection already closed") + return net.ErrClosed } if container.conn != connReference { diff --git a/relay/client/conn.go b/relay/client/conn.go index b4ff903e8..fe1b6fb52 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -1,7 +1,6 @@ package client import ( - "io" "net" "time" ) @@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) { msg, ok := <-c.messageChan if !ok { - return 0, io.EOF + return 0, net.ErrClosed } n = copy(b, msg.Payload) diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index c248963b9..12e721fdb 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "sync" "time" @@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool { func (c *Conn) ioErrHandling(err error) error { if c.isClosed() { - return io.EOF + return net.ErrClosed } var wErr *websocket.CloseError @@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error { return err } if wErr.Code == websocket.StatusNormalClosure { - return io.EOF + return net.ErrClosed } return err } diff --git a/relay/server/peer.go b/relay/server/peer.go index a9c542f84..c909c35d5 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -2,7 +2,7 @@ package server import ( "context" - "io" + "errors" "net" "sync" "time" @@ -57,7 +57,7 @@ func (p *Peer) Work() { for { n, err := p.conn.Read(buf) if err != nil { - if err != io.EOF { + if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) } return