Add safe read/write to route map (#1760)

This commit is contained in:
Carlos Hernandez 2024-04-11 14:12:23 -06:00 committed by GitHub
parent 061f673a4f
commit 76702c8a09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 49 additions and 11 deletions

View File

@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(), FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected, ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
} }
e.statusRecorder.ReplaceOfflinePeers(replacement) e.statusRecorder.ReplaceOfflinePeers(replacement)

View File

@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
} }
conn.agent, err = ice.NewAgent(agentConfig) conn.agent, err = ice.NewAgent(agentConfig)
if err != nil { if err != nil {
return err return err
} }
@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: conn.status, ConnStatus: conn.status,
Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.status, ConnStatus: conn.status,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
@ -468,6 +469,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local), Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled, RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
} }
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.status, ConnStatus: conn.status,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {

View File

@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer // State contains the latest state of a peer
type State struct { type State struct {
Mux *sync.RWMutex
IP string IP string
PubKey string PubKey string
FQDN string FQDN string
@ -30,7 +31,38 @@ type State struct {
BytesRx int64 BytesRx int64
Latency time.Duration Latency time.Duration
RosenpassEnabled bool RosenpassEnabled bool
Routes map[string]struct{} routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
} }
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey, PubKey: peerPubKey,
ConnStatus: StatusDisconnected, ConnStatus: StatusDisconnected,
FQDN: fqdn, FQDN: fqdn,
Mux: new(sync.RWMutex),
} }
d.peerListChangedForNotification = true d.peerListChangedForNotification = true
return nil return nil
@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP peerState.IP = receivedState.IP
} }
if receivedState.Routes != nil { if receivedState.GetRoutes() != nil {
peerState.Routes = receivedState.Routes peerState.SetRoutes(receivedState.GetRoutes())
} }
skipNotification := shouldSkipNotify(receivedState, peerState) skipNotification := shouldSkipNotify(receivedState, peerState)
@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError) s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true return true
} }
return false return false
} }

View File

@ -3,6 +3,7 @@ package peer
import ( import (
"errors" "errors"
"testing" "testing"
"sync"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState
@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
Mux: new(sync.RWMutex),
} }
status.peers[key] = peerState status.peers[key] = peerState

View File

@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err) return fmt.Errorf("get peer state: %v", err)
} }
delete(state.Routes, c.network.String()) state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil { if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }
@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil { if err != nil {
log.Errorf("Failed to get peer state: %v", err) log.Errorf("Failed to get peer state: %v", err)
} else { } else {
if state.Routes == nil { state.AddRoute(c.network.String())
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
if err := c.statusRecorder.UpdatePeerState(state); err != nil { if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }

View File

@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx, BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx, BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled, RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes), Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency), Latency: durationpb.New(peerState.Latency),
} }
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)