diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index a28992fac..241dfabbb 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.GetRoutes() != nil { - peerState.SetRoutes(receivedState.GetRoutes()) - } - skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) if receivedState.ConnStatus != peerState.ConnStatus { @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil + d.notifyPeerListChanged() + return nil +} + +func (d *Status) AddPeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") } + peerState.AddRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not + d.notifyPeerListChanged() + return nil +} + +func (d *Status) RemovePeerStateRoute(peer string, route string) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peer] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.DeleteRoute(route) + d.peers[peer] = peerState + + // todo: consider to make sense of this notification or not d.notifyPeerListChanged() return nil } @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { return nil } - ch, found := d.changeNotify[receivedState.PubKey] - if found && ch != nil { - close(ch) - d.changeNotify[receivedState.PubKey] = nil - } - + d.notifyPeerStateChangeListeners(receivedState.PubKey) d.notifyPeerListChanged() return nil } @@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { d.mux.Lock() defer d.mux.Unlock() + ch, found := d.changeNotify[peer] - if !found || ch == nil { - ch = make(chan struct{}) - d.changeNotify[peer] = ch + if found { + return ch } + + ch = make(chan struct{}) + d.changeNotify[peer] = ch return ch } @@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() { d.notifier.updateServerStates(d.managementState, d.signalState) } +// notifyPeerStateChangeListeners notifies route manager about the change in peer state +func (d *Status) notifyPeerStateChangeListeners(peerID string) { + ch, found := d.changeNotify[peerID] + if !found { + return + } + + close(ch) + delete(d.changeNotify, peerID) +} + func (d *Status) notifyPeerListChanged() { d.notifier.peerListChanged(d.numOfPeers()) } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 1d283433b..931ec9005 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { peerState.IP = ip - err := status.UpdatePeerState(peerState) + err := status.UpdatePeerRelayedStateToDisconnected(peerState) assert.NoError(t, err, "shouldn't return error") select { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 55894218d..4c67cb781 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -57,6 +57,9 @@ type WorkerICE struct { localUfrag string localPwd string + + // we record the last known state of the ICE agent to avoid duplicate on disconnected events + lastKnownState ice.ConnectionState } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { @@ -194,8 +197,7 @@ func (w *WorkerICE) Close() { return } - err := w.agent.Close() - if err != nil { + if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } } @@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) - if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { - w.conn.OnStatusChanged(StatusDisconnected) - - w.muxAgent.Lock() - agentCancel() - _ = agent.Close() - w.agent = nil - - w.muxAgent.Unlock() + switch state { + case ice.ConnectionStateConnected: + w.lastKnownState = ice.ConnectionStateConnected + return + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + if w.lastKnownState != ice.ConnectionStateDisconnected { + w.lastKnownState = ice.ConnectionStateDisconnected + w.conn.OnStatusChanged(StatusDisconnected) + } + w.closeAgent(agentCancel) + default: + return } }) if err != nil { @@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i return agent, nil } +func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + + cancel() + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } + w.agent = nil +} + func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index eaa232151..13e45b3a3 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] tempScore = float64(metricDiff) * 10 } - // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route - latency := time.Second + // in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route + latency := 999 * time.Millisecond if peerStatus.latency != 0 { latency = peerStatus.latency } else { - log.Warnf("peer %s has 0 latency", r.Peer) + log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) } + + // avoid negative tempScore on the higher latency calculation + if latency > 1*time.Second { + latency = 999 * time.Millisecond + } + + // higher latency is worse score tempScore += 1 - latency.Seconds() if !peerStatus.relayed { @@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] } } + log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) + switch { case chosen == "": var peers []string @@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri func (c *clientNetwork) startPeersStatusChangeWatcher() { for _, r := range c.routes { _, found := c.routePeersNotifiers[r.Peer] - if !found { - c.routePeersNotifiers[r.Peer] = make(chan struct{}) - go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer]) + if found { + continue } + + closerChan := make(chan struct{}) + c.routePeersNotifiers[r.Peer] = closerChan + go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) } } -func (c *clientNetwork) removeRouteFromWireguardPeer() error { - c.removeStateRoute() +func (c *clientNetwork) removeRouteFromWireGuardPeer() error { + if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } if err := c.handler.RemoveAllowedIPs(); err != nil { return fmt.Errorf("remove allowed IPs: %w", err) @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { var merr *multierror.Error - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) } if err := c.handler.RemoveRoute(); err != nil { @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireguardPeer(); err != nil { + if err := c.removeRouteFromWireGuardPeer(); err != nil { return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } } @@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } - c.addStateRoute() - + err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) + if err != nil { + return fmt.Errorf("add peer state route: %w", err) + } return nil } -func (c *clientNetwork) addStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.AddRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - -func (c *clientNetwork) removeStateRoute() { - state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - return - } - - state.DeleteRoute(c.handler.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } -} - func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { go func() { c.routeUpdate <- update diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index c121b7d77..0e230ef40 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"`