diff --git a/client/cmd/service.go b/client/cmd/service.go index 855eb30fa..3560088a7 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "sync" "github.com/kardianos/service" log "github.com/sirupsen/logrus" @@ -13,10 +14,11 @@ import ( ) type program struct { - ctx context.Context - cancel context.CancelFunc - serv *grpc.Server - serverInstance *server.Server + ctx context.Context + cancel context.CancelFunc + serv *grpc.Server + serverInstance *server.Server + serverInstanceMu sync.Mutex } func newProgram(ctx context.Context, cancel context.CancelFunc) *program { diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 86546e31c..761c86628 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error { } proto.RegisterDaemonServiceServer(p.serv, serverInstance) + p.serverInstanceMu.Lock() p.serverInstance = serverInstance + p.serverInstanceMu.Unlock() log.Printf("started daemon server: %v", split[1]) if err := p.serv.Serve(listen); err != nil { @@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error { } func (p *program) Stop(srv service.Service) error { + p.serverInstanceMu.Lock() if p.serverInstance != nil { in := new(proto.DownRequest) _, err := p.serverInstance.Down(p.ctx, in) @@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error { log.Errorf("failed to stop daemon: %v", err) } } + p.serverInstanceMu.Unlock() p.cancel() diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e0883715a..8a2e65382 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -2,6 +2,7 @@ package bind import ( "context" + "errors" "fmt" "net" "net/netip" @@ -94,7 +95,10 @@ func (p *ProxyBind) close() error { p.Bind.RemoveEndpoint(p.wgAddr) - return p.remoteConn.Close() + if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { + return rErr + } + return nil } func (p *ProxyBind) proxyToLocal(ctx context.Context) { diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index efd5fd946..54cab4e1b 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error { e.cancel() - if err := e.remoteConn.Close(); err != nil { + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { return fmt.Errorf("failed to close remote conn: %w", err) } return nil diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 200d961f3..ba0004b8a 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error { p.cancel() var result *multierror.Error - if err := p.remoteConn.Close(); err != nil { + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } diff --git a/client/internal/connect.go b/client/internal/connect.go index bcc9d17a3..dff44f1d2 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold c.statusRecorder.MarkSignalDisconnected(nil) defer func() { - c.statusRecorder.MarkSignalDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkSignalDisconnected(err) }() // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 84a8c221f..81c456db7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.iceP2PIsActive() { conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.statusRelay.Set(StatusConnected) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) return @@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay conn.statusRelay.Set(StatusConnected) - conn.wgProxyRelay = wgProxy + conn.setRelayedProxy(wgProxy) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) @@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() { } } +func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { + if conn.wgProxyRelay != nil { + if err := conn.wgProxyRelay.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + conn.wgProxyRelay = proxy +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index a28992fac..0444dc60b 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) { func (s *State) GetRoutes() map[string]struct{} { s.Mux.RLock() defer s.Mux.RUnlock() - return s.routes + return maps.Clone(s.routes) } // LocalPeerState contains the latest state of the local peer @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.GetRoutes() != nil { - peerState.SetRoutes(receivedState.GetRoutes()) - } - skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) if receivedState.ConnStatus != peerState.ConnStatus { @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil + d.notifyPeerListChanged() + return nil +} + +func (d *Status) AddPeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") } + peerState.AddRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not + d.notifyPeerListChanged() + return nil +} + +func (d *Status) RemovePeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.DeleteRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not d.notifyPeerListChanged() return nil } @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { d.mux.Lock() defer d.mux.Unlock() + ch, found := d.changeNotify[peer] - if !found || ch == nil { - ch = make(chan struct{}) - d.changeNotify[peer] = ch + if found { + return ch } + + ch = make(chan struct{}) + d.changeNotify[peer] = ch return ch } @@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() { d.notifier.updateServerStates(d.managementState, d.signalState) } +// notifyPeerStateChangeListeners notifies route manager about the change in peer state +func (d *Status) notifyPeerStateChangeListeners(peerID string) { + ch, found := d.changeNotify[peerID] + if !found { + return + } + + close(ch) + delete(d.changeNotify, peerID) +} + func (d *Status) notifyPeerListChanged() { d.notifier.peerListChanged(d.numOfPeers()) } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 1d283433b..931ec9005 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { peerState.IP = ip - err := status.UpdatePeerState(peerState) + err := status.UpdatePeerRelayedStateToDisconnected(peerState) assert.NoError(t, err, "shouldn't return error") select { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 55894218d..4c67cb781 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -57,6 +57,9 @@ type WorkerICE struct { localUfrag string localPwd string + + // we record the last known state of the ICE agent to avoid duplicate on disconnected events + lastKnownState ice.ConnectionState } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { @@ -194,8 +197,7 @@ func (w *WorkerICE) Close() { return } - err := w.agent.Close() - if err != nil { + if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } } @@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) - if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { - w.conn.OnStatusChanged(StatusDisconnected) - - w.muxAgent.Lock() - agentCancel() - _ = agent.Close() - w.agent = nil - - w.muxAgent.Unlock() + switch state { + case ice.ConnectionStateConnected: + w.lastKnownState = ice.ConnectionStateConnected + return + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + if w.lastKnownState != ice.ConnectionStateDisconnected { + w.lastKnownState = ice.ConnectionStateDisconnected + w.conn.OnStatusChanged(StatusDisconnected) + } + w.closeAgent(agentCancel) + default: + return } }) if err != nil { @@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i return agent, nil } +func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + + cancel() + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } + w.agent = nil +} + func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index eaa232151..13e45b3a3 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] tempScore = float64(metricDiff) * 10 } - // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route - latency := time.Second + // in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route + latency := 999 * time.Millisecond if peerStatus.latency != 0 { latency = peerStatus.latency } else { - log.Warnf("peer %s has 0 latency", r.Peer) + log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) } + + // avoid negative tempScore on the higher latency calculation + if latency > 1*time.Second { + latency = 999 * time.Millisecond + } + + // higher latency is worse score tempScore += 1 - latency.Seconds() if !peerStatus.relayed { @@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] } } + log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) + switch { case chosen == "": var peers []string @@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri func (c *clientNetwork) startPeersStatusChangeWatcher() { for _, r := range c.routes { _, found := c.routePeersNotifiers[r.Peer] - if !found { - c.routePeersNotifiers[r.Peer] = make(chan struct{}) - go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer]) + if found { + continue } + + closerChan := make(chan struct{}) + c.routePeersNotifiers[r.Peer] = closerChan + go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) } } -func (c *clientNetwork) removeRouteFromWireguardPeer() error { - c.removeStateRoute() +func (c *clientNetwork) removeRouteFromWireGuardPeer() error { + if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } if err := c.handler.RemoveAllowedIPs(); err != nil { return fmt.Errorf("remove allowed IPs: %w", err) @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { var merr *multierror.Error - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) } if err := c.handler.RemoveRoute(); err != nil { @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } } @@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } - c.addStateRoute() - + err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) + if err != nil { + return fmt.Errorf("add peer state route: %w", err) + } return nil } -func (c *clientNetwork) addStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.AddRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - -func (c *clientNetwork) removeStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.DeleteRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { go func() { c.routeUpdate <- update diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index c121b7d77..0e230ef40 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4c4ef6c3c..efe088b27 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -38,6 +39,7 @@ type GRPCServer struct { jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager + peerLocks sync.Map } // NewServer creates a new Management server @@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck @@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } + unlock() + unlock = nil + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w } func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { + unlock := s.acquirePeerLockByUID(ctx, peer.Key) + defer unlock() + + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } @@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return claims.UserId, nil } +func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) + + start := time.Now() + value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.Lock() + log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start)) + start = time.Now() + + unlock = func() { + mtx.Unlock() + log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start)) + } + + return unlock +} + // maps internal internalStatus.Error to gRPC status.Error func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok { diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index ce4edee4f..f44a164e2 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro } if req.Peer == nil && req.PeerGroups == nil { - return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided") + return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided") } if req.Peer != nil && req.PeerGroups != nil { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index d6e92fe3a..554c66ba4 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" b64 "encoding/base64" "hash/fnv" + "slices" "strconv" "strings" "time" @@ -236,6 +237,10 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + var setupKey *SetupKey var plainKey string var eventsToStore []func() @@ -289,6 +294,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + var oldKey *SetupKey var newKey *SetupKey var eventsToStore []func() @@ -414,10 +423,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, } func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) + if err != nil { + return err + } + for _, groupID := range autoGroupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err + group, ok := groups[groupID] + if !ok { + return status.Errorf(status.NotFound, "group not found: %s", groupID) } if group.IsGroupAll() { @@ -432,26 +446,37 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { var eventsToStore []func() + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + return nil + } + for _, g := range removedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) - if err != nil { + group, ok := groups[g] + if !ok { log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) continue } - meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} - am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + }) } for _, g := range addedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) - if err != nil { + group, ok := groups[g] + if !ok { log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) continue } - meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} - am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) + }) } return eventsToStore diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 8a0f432e6..730fb9900 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -485,9 +485,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + return nil, status.NewSetupKeyNotFoundError(setupKey) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account by setup key from store") } if key.AccountID == "" { @@ -570,7 +571,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -756,9 +757,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + return "", status.NewSetupKeyNotFoundError(setupKey) } - return "", status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error) + return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store") } if accountID == "" { @@ -985,9 +987,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "setup key not found") + return nil, status.NewSetupKeyNotFoundError(key) } - return nil, status.NewSetupKeyNotFoundError(result.Error) + log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") } return &setupKey, nil } @@ -1005,7 +1008,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "setup key not found") + return status.NewSetupKeyNotFoundError(setupKeyID) } return nil @@ -1207,6 +1210,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// GetGroupsByIDs retrieves groups by their IDs and account ID. +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { + var groups []*nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + } + + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + + return groupsMap, nil +} + // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) @@ -1278,7 +1298,7 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { var setupKeys []*SetupKey - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) @@ -1291,11 +1311,11 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking // GetSetupKeyByID retrieves a setup key by its ID and account ID. func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { var setupKey *SetupKey - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "setup key not found") + return nil, status.NewSetupKeyNotFoundError(setupKeyID) } log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get setup key from store") @@ -1306,8 +1326,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre // SaveSetupKey saves a setup key to the database. func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { - result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save setup key to store") @@ -1318,15 +1337,14 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "setup key not found") + return status.NewSetupKeyNotFoundError(keyID) } return nil diff --git a/management/server/status/error.go b/management/server/status/error.go index 00be347ad..6957a7e05 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -103,8 +103,8 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError(err error) error { - return Errorf(NotFound, "setup key not found: %s", err) +func NewSetupKeyNotFoundError(setupKeyID string) error { + return Errorf(NotFound, "setup key: %s not found", setupKeyID) } func NewGetAccountFromStoreError(err error) error { @@ -126,11 +126,6 @@ func NewAdminPermissionError() error { return Errorf(PermissionDenied, "admin role required to perform this action") } -// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context -func NewStoreContextCanceledError(duration time.Duration) error { - return Errorf(Internal, "store access: context canceled after %v", duration) -} - // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") diff --git a/management/server/store.go b/management/server/store.go index 68b57204b..2a0c44c67 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -71,8 +71,9 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error diff --git a/relay/client/client.go b/relay/client/client.go index a82a75453..154c1787f 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "fmt" - "io" "net" "sync" "time" @@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ conn, ok := c.conns[id] c.mu.Unlock() if !ok { - return 0, io.EOF + return 0, net.ErrClosed } if conn.conn != connReference { - return 0, io.EOF + return 0, net.ErrClosed } // todo: use buffer pool instead of create new transport msg. @@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error { container, ok := c.conns[id] if !ok { - return fmt.Errorf("connection already closed") + return net.ErrClosed } if container.conn != connReference { diff --git a/relay/client/conn.go b/relay/client/conn.go index b4ff903e8..fe1b6fb52 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -1,7 +1,6 @@ package client import ( - "io" "net" "time" ) @@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) { msg, ok := <-c.messageChan if !ok { - return 0, io.EOF + return 0, net.ErrClosed } n = copy(b, msg.Payload) diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index c248963b9..12e721fdb 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "sync" "time" @@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool { func (c *Conn) ioErrHandling(err error) error { if c.isClosed() { - return io.EOF + return net.ErrClosed } var wErr *websocket.CloseError @@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error { return err } if wErr.Code == websocket.StatusNormalClosure { - return io.EOF + return net.ErrClosed } return err } diff --git a/relay/server/peer.go b/relay/server/peer.go index a9c542f84..c909c35d5 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -2,7 +2,7 @@ package server import ( "context" - "io" + "errors" "net" "sync" "time" @@ -57,7 +57,7 @@ func (p *Peer) Work() { for { n, err := p.conn.Read(buf) if err != nil { - if err != io.EOF { + if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) } return