mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-30 14:40:15 +02:00
Fix concurrency on the client (#183)
* reworked peer connection establishment logic eliminating race conditions and deadlocks while running many peers
This commit is contained in:
@ -3,9 +3,10 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/client/internal/peer"
|
||||
"github.com/wiretrustee/wiretrustee/client/internal/proxy"
|
||||
"github.com/wiretrustee/wiretrustee/iface"
|
||||
mgm "github.com/wiretrustee/wiretrustee/management/client"
|
||||
mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||
@ -20,11 +21,14 @@ import (
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
// E.g. this peer will wait PeerConnectionTimeoutMax for the remote peer to respond, if not successful then it will retry the connection attempt.
|
||||
const PeerConnectionTimeoutMax = 45 //sec
|
||||
const PeerConnectionTimeoutMin = 30 //sec
|
||||
const PeerConnectionTimeoutMax = 45000 //ms
|
||||
const PeerConnectionTimeoutMin = 30000 //ms
|
||||
|
||||
const WgPort = 51820
|
||||
|
||||
// EngineConfig is a config for the Engine
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIface string
|
||||
// WgAddr is a Wireguard local address (Wiretrustee Network IP)
|
||||
WgAddr string
|
||||
@ -42,21 +46,13 @@ type Engine struct {
|
||||
signal *signal.Client
|
||||
// mgmClient is a Management Service client
|
||||
mgmClient *mgm.Client
|
||||
// conns is a collection of remote peer connections indexed by local public key of the remote peers
|
||||
conns map[string]*Connection
|
||||
// peerMap is a map that holds all the peers that are known to this peer
|
||||
peerMap map[string]struct{}
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
// peerMux is used to sync peer operations (e.g. open connection, peer removal)
|
||||
peerMux *sync.Mutex
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
|
||||
// wgPort is a Wireguard local listen port
|
||||
wgPort int
|
||||
|
||||
// STUNs is a list of STUN servers used by ICE
|
||||
STUNs []*ice.URL
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
@ -78,9 +74,7 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin
|
||||
return &Engine{
|
||||
signal: signalClient,
|
||||
mgmClient: mgmClient,
|
||||
conns: map[string]*Connection{},
|
||||
peerMap: map[string]struct{}{},
|
||||
peerMux: &sync.Mutex{},
|
||||
peerConns: map[string]*peer.Conn{},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
STUNs: []*ice.URL{},
|
||||
@ -91,13 +85,16 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
err := e.removeAllPeerConnections()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("removing Wiretrustee interface %s", e.config.WgIface)
|
||||
err = iface.Close()
|
||||
err = iface.Close(e.config.WgIface)
|
||||
if err != nil {
|
||||
log.Errorf("failed closing Wiretrustee interface %s %v", e.config.WgIface, err)
|
||||
return err
|
||||
@ -112,6 +109,8 @@ func (e *Engine) Stop() error {
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
wgIface := e.config.WgIface
|
||||
wgAddr := e.config.WgAddr
|
||||
@ -123,86 +122,33 @@ func (e *Engine) Start() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = iface.Configure(wgIface, myPrivateKey.String())
|
||||
err = iface.Configure(wgIface, myPrivateKey.String(), e.config.WgPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIface, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
port, err := iface.GetListenPort(wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting Wireguard listen port [%s]: %s", wgIface, err.Error())
|
||||
return err
|
||||
}
|
||||
e.wgPort = *port
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializePeer peer agent attempt to open connection
|
||||
func (e *Engine) initializePeer(peer Peer) {
|
||||
|
||||
e.peerMap[peer.WgPubKey] = struct{}{}
|
||||
|
||||
var backOff = backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: backoff.DefaultInitialInterval,
|
||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: 5 * time.Second,
|
||||
MaxElapsedTime: 0, //never stop
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, e.ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
if e.signal.GetStatus() != signal.StreamConnected {
|
||||
return fmt.Errorf("not opening connection to peer because Signal is unavailable")
|
||||
}
|
||||
|
||||
_, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer)
|
||||
e.peerMux.Lock()
|
||||
defer e.peerMux.Unlock()
|
||||
if _, ok := e.peerMap[peer.WgPubKey]; !ok {
|
||||
log.Debugf("peer was removed: %v, stop connecting", peer.WgPubKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("retrying connection because of error: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
// should actually never happen
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) removePeers(peers []string) error {
|
||||
for _, peer := range peers {
|
||||
err := e.removePeer(peer)
|
||||
for _, p := range peers {
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("removed peer %s", p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) removeAllPeerConnections() error {
|
||||
log.Debugf("removing all peer connections")
|
||||
e.peerMux.Lock()
|
||||
defer e.peerMux.Unlock()
|
||||
for peer := range e.conns {
|
||||
err := e.removePeer(peer)
|
||||
for p := range e.peerConns {
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -212,69 +158,39 @@ func (e *Engine) removeAllPeerConnections() error {
|
||||
|
||||
// removePeer closes an existing peer connection and removes a peer
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
|
||||
delete(e.peerMap, peerKey)
|
||||
|
||||
conn, exists := e.conns[peerKey]
|
||||
if exists && conn != nil {
|
||||
delete(e.conns, peerKey)
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
conn, exists := e.peerConns[peerKey]
|
||||
if exists {
|
||||
delete(e.peerConns, peerKey)
|
||||
return conn.Close()
|
||||
}
|
||||
log.Infof("removed peer %s", peerKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerConnectionStatus returns a connection Status or nil if peer connection wasn't found
|
||||
func (e *Engine) GetPeerConnectionStatus(peerKey string) *Status {
|
||||
e.peerMux.Lock()
|
||||
defer e.peerMux.Unlock()
|
||||
func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
||||
|
||||
conn, exists := e.conns[peerKey]
|
||||
conn, exists := e.peerConns[peerKey]
|
||||
if exists && conn != nil {
|
||||
return &conn.Status
|
||||
return conn.Status()
|
||||
}
|
||||
|
||||
return nil
|
||||
return -1
|
||||
}
|
||||
|
||||
// openPeerConnection opens a new remote peer connection
|
||||
func (e *Engine) openPeerConnection(wgPort int, myKey wgtypes.Key, peer Peer) (*Connection, error) {
|
||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func (e *Engine) GetConnectedPeers() []string {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
remoteKey, _ := wgtypes.ParseKey(peer.WgPubKey)
|
||||
connConfig := &ConnConfig{
|
||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", wgPort),
|
||||
WgPeerIP: e.config.WgAddr,
|
||||
WgIface: e.config.WgIface,
|
||||
WgAllowedIPs: peer.WgAllowedIps,
|
||||
WgKey: myKey,
|
||||
RemoteWgKey: remoteKey,
|
||||
StunTurnURLS: append(e.STUNs, e.TURNs...),
|
||||
iFaceBlackList: e.config.IFaceBlackList,
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
peers := []string{}
|
||||
for s, conn := range e.peerConns {
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
peers = append(peers, s)
|
||||
}
|
||||
}
|
||||
|
||||
signalOffer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, myKey, remoteKey, e.signal, false)
|
||||
}
|
||||
|
||||
signalAnswer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, myKey, remoteKey, e.signal, true)
|
||||
}
|
||||
signalCandidate := func(candidate ice.Candidate) error {
|
||||
return signalCandidate(candidate, myKey, remoteKey, e.signal)
|
||||
}
|
||||
conn := NewConnection(*connConfig, signalCandidate, signalOffer, signalAnswer)
|
||||
e.peerMux.Lock()
|
||||
e.conns[remoteKey.String()] = conn
|
||||
e.peerMux.Unlock()
|
||||
|
||||
// blocks until the connection is open (or timeout)
|
||||
timeout := rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin) + PeerConnectionTimeoutMin
|
||||
err := conn.Open(time.Duration(timeout) * time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
return peers
|
||||
}
|
||||
|
||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client) error {
|
||||
@ -400,17 +316,15 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
e.peerMux.Lock()
|
||||
defer e.peerMux.Unlock()
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers))
|
||||
remotePeerMap := make(map[string]struct{})
|
||||
for _, peer := range remotePeers {
|
||||
remotePeerMap[peer.GetWgPubKey()] = struct{}{}
|
||||
for _, p := range remotePeers {
|
||||
remotePeerMap[p.GetWgPubKey()] = struct{}{}
|
||||
}
|
||||
|
||||
//remove peers that are no longer available for us
|
||||
toRemove := []string{}
|
||||
for p := range e.conns {
|
||||
for p := range e.peerConns {
|
||||
if _, ok := remotePeerMap[p]; !ok {
|
||||
toRemove = append(toRemove, p)
|
||||
}
|
||||
@ -421,20 +335,115 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
}
|
||||
|
||||
// add new peers
|
||||
for _, peer := range remotePeers {
|
||||
peerKey := peer.GetWgPubKey()
|
||||
peerIPs := peer.GetAllowedIps()
|
||||
if _, ok := e.peerMap[peerKey]; !ok {
|
||||
e.initializePeer(Peer{
|
||||
WgPubKey: peerKey,
|
||||
WgAllowedIps: strings.Join(peerIPs, ","),
|
||||
})
|
||||
for _, p := range remotePeers {
|
||||
peerKey := p.GetWgPubKey()
|
||||
peerIPs := p.GetAllowedIps()
|
||||
if _, ok := e.peerConns[peerKey]; !ok {
|
||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.peerConns[peerKey] = conn
|
||||
|
||||
go e.connWorker(conn, peerKey)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
for {
|
||||
|
||||
// randomize starting time a bit
|
||||
min := 500
|
||||
max := 2000
|
||||
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
|
||||
|
||||
// if peer has been removed -> give up
|
||||
if !e.peerExists(peerKey) {
|
||||
log.Infof("peer %s doesn't exist anymore, won't retry connection", peerKey)
|
||||
return
|
||||
}
|
||||
|
||||
if !e.signal.Ready() {
|
||||
log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
|
||||
continue
|
||||
}
|
||||
|
||||
err := conn.Open()
|
||||
if err != nil {
|
||||
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e Engine) peerExists(peerKey string) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
_, ok := e.peerConns[peerKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
|
||||
|
||||
var stunTurn []*ice.URL
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
|
||||
interfaceBlacklist := make([]string, 0, len(e.config.IFaceBlackList))
|
||||
for k := range e.config.IFaceBlackList {
|
||||
interfaceBlacklist = append(interfaceBlacklist, k)
|
||||
}
|
||||
|
||||
proxyConfig := proxy.Config{
|
||||
RemoteKey: pubKey,
|
||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
||||
WgInterface: e.config.WgIface,
|
||||
AllowedIps: allowedIPs,
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
}
|
||||
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
StunTurn: stunTurn,
|
||||
InterfaceBlackList: interfaceBlacklist,
|
||||
Timeout: timeout,
|
||||
ProxyConfig: proxyConfig,
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgPubKey, err := wgtypes.ParseKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signalOffer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, false)
|
||||
}
|
||||
|
||||
signalCandidate := func(candidate ice.Candidate) error {
|
||||
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
|
||||
}
|
||||
|
||||
signalAnswer := func(uFrag string, pwd string) error {
|
||||
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, true)
|
||||
}
|
||||
|
||||
peerConn.SetSignalCandidate(signalCandidate)
|
||||
peerConn.SetSignalOffer(signalOffer)
|
||||
peerConn.SetSignalAnswer(signalAnswer)
|
||||
|
||||
return peerConn, nil
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
|
||||
@ -445,58 +454,37 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn := e.conns[msg.Key]
|
||||
conn := e.peerConns[msg.Key]
|
||||
if conn == nil {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
|
||||
if conn.Config.RemoteWgKey.String() != msg.Key {
|
||||
return fmt.Errorf("unknown peer %s", msg.Key)
|
||||
}
|
||||
|
||||
switch msg.GetBody().Type {
|
||||
case sProto.Body_OFFER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = conn.OnOffer(IceCredentials{
|
||||
uFrag: remoteCred.UFrag,
|
||||
pwd: remoteCred.Pwd,
|
||||
conn.OnRemoteOffer(peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
case sProto.Body_ANSWER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = conn.OnAnswer(IceCredentials{
|
||||
uFrag: remoteCred.UFrag,
|
||||
pwd: remoteCred.Pwd,
|
||||
conn.OnRemoteAnswer(peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
Pwd: remoteCred.Pwd,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case sProto.Body_CANDIDATE:
|
||||
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = conn.OnRemoteCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("error handling CANDIATE from %s", msg.Key)
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteCandidate(candidate)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
Reference in New Issue
Block a user