mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 10:18:50 +02:00
Fix memory leak
Avoid to add listeners to multiple times
This commit is contained in:
parent
b62ad97e59
commit
d70df99f7b
@ -34,18 +34,18 @@ type WorkerRelay struct {
|
|||||||
relayManager relayClient.ManagerService
|
relayManager relayClient.ManagerService
|
||||||
conn WorkerRelayCallbacks
|
conn WorkerRelayCallbacks
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
||||||
return &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
parentCtx: ctx,
|
parentCtx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
conn: callbacks,
|
conn: callbacks,
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||||
@ -63,7 +63,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||||
|
|
||||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.disconnected)
|
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// todo handle all type errors
|
// todo handle all type errors
|
||||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||||
@ -74,11 +74,20 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.ctx, w.ctxCancel = context.WithCancel(w.parentCtx)
|
ctx, ctxCancel := context.WithCancel(w.parentCtx)
|
||||||
|
w.ctxCancel = ctxCancel
|
||||||
|
|
||||||
go w.wgStateCheck(relayedConn)
|
err = w.relayManager.AddCloseListener(srv, w.disconnected)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add close listener: %s", err)
|
||||||
|
_ = relayedConn.Close()
|
||||||
|
ctxCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.log.Debugf("Relay connection established with %s", srv)
|
go w.wgStateCheck(ctx, relayedConn)
|
||||||
|
|
||||||
|
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||||
go w.conn.OnConnReady(RelayConnInfo{
|
go w.conn.OnConnReady(RelayConnInfo{
|
||||||
relayedConn: relayedConn,
|
relayedConn: relayedConn,
|
||||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||||
@ -99,7 +108,7 @@ func (w *WorkerRelay) RelayIsSupportedLocally() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the wireguard handshake and relay connection
|
// wgStateCheck help to check the state of the wireguard handshake and relay connection
|
||||||
func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
|
func (w *WorkerRelay) wgStateCheck(ctx context.Context, conn net.Conn) {
|
||||||
timer := time.NewTimer(wgHandshakeOvertime)
|
timer := time.NewTimer(wgHandshakeOvertime)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
for {
|
for {
|
||||||
@ -120,7 +129,7 @@ func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod))
|
resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod))
|
||||||
timer.Reset(resetTime)
|
timer.Reset(resetTime)
|
||||||
case <-w.ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -149,6 +158,8 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) disconnected() {
|
func (w *WorkerRelay) disconnected() {
|
||||||
|
if w.ctxCancel != nil {
|
||||||
w.ctxCancel()
|
w.ctxCancel()
|
||||||
|
}
|
||||||
w.conn.OnDisconnected()
|
w.conn.OnDisconnected()
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -30,9 +32,12 @@ func NewRelayTrack() *RelayTrack {
|
|||||||
return &RelayTrack{}
|
return &RelayTrack{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OnServerCloseListener func()
|
||||||
|
|
||||||
type ManagerService interface {
|
type ManagerService interface {
|
||||||
Serve() error
|
Serve() error
|
||||||
OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error)
|
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
||||||
|
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||||
RelayInstanceAddress() (string, error)
|
RelayInstanceAddress() (string, error)
|
||||||
ServerURL() string
|
ServerURL() string
|
||||||
HasRelayAddress() bool
|
HasRelayAddress() bool
|
||||||
@ -57,7 +62,7 @@ type Manager struct {
|
|||||||
relayClients map[string]*RelayTrack
|
relayClients map[string]*RelayTrack
|
||||||
relayClientsMutex sync.RWMutex
|
relayClientsMutex sync.RWMutex
|
||||||
|
|
||||||
onDisconnectedListeners map[string]map[*func()]struct{}
|
onDisconnectedListeners map[string]*list.List
|
||||||
listenerLock sync.Mutex
|
listenerLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,7 +73,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
|
|||||||
peerID: peerID,
|
peerID: peerID,
|
||||||
tokenStore: &relayAuth.TokenStore{},
|
tokenStore: &relayAuth.TokenStore{},
|
||||||
relayClients: make(map[string]*RelayTrack),
|
relayClients: make(map[string]*RelayTrack),
|
||||||
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
|
onDisconnectedListeners: make(map[string]*list.List),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +102,7 @@ func (m *Manager) Serve() error {
|
|||||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||||
// connection to the relay server.
|
// connection to the relay server.
|
||||||
func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) {
|
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||||
if m.relayClient == nil {
|
if m.relayClient == nil {
|
||||||
return nil, errRelayClientNotConnected
|
return nil, errRelayClientNotConnected
|
||||||
}
|
}
|
||||||
@ -121,19 +126,23 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if onClosedListener != nil {
|
return netConn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||||
|
foreign, err := m.isForeignServer(serverAddress)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var listenerAddr string
|
var listenerAddr string
|
||||||
if foreign {
|
if foreign {
|
||||||
m.addListener(serverAddress, onClosedListener)
|
|
||||||
listenerAddr = serverAddress
|
listenerAddr = serverAddress
|
||||||
} else {
|
} else {
|
||||||
listenerAddr = m.serverURL
|
listenerAddr = m.serverURL
|
||||||
}
|
}
|
||||||
m.addListener(listenerAddr, onClosedListener)
|
m.addListener(listenerAddr, onClosedListener)
|
||||||
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
return netConn, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
|
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
|
||||||
@ -265,14 +274,19 @@ func (m *Manager) cleanUpUnusedRelays() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) addListener(serverAddress string, onClosedListener func()) {
|
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||||
m.listenerLock.Lock()
|
m.listenerLock.Lock()
|
||||||
defer m.listenerLock.Unlock()
|
defer m.listenerLock.Unlock()
|
||||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||||
if !ok {
|
if !ok {
|
||||||
l = make(map[*func()]struct{})
|
l = list.New()
|
||||||
}
|
}
|
||||||
l[&onClosedListener] = struct{}{}
|
for e := l.Front(); e != nil; e = e.Next() {
|
||||||
|
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.PushBack(onClosedListener)
|
||||||
m.onDisconnectedListeners[serverAddress] = l
|
m.onDisconnectedListeners[serverAddress] = l
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,8 +298,8 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for f := range l {
|
for e := l.Front(); e != nil; e = e.Next() {
|
||||||
go (*f)()
|
go e.Value.(OnServerCloseListener)()
|
||||||
}
|
}
|
||||||
delete(m.onDisconnectedListeners, serverAddress)
|
delete(m.onDisconnectedListeners, serverAddress)
|
||||||
}
|
}
|
||||||
|
@ -87,11 +87,11 @@ func TestForeignConn(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get relay address: %s", err)
|
t.Fatalf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil)
|
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil)
|
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -187,7 +187,7 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
|
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -269,7 +269,7 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Log("open connection to another peer")
|
t.Log("open connection to another peer")
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
|
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -330,7 +330,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get relay address: %s", err)
|
t.Errorf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := clientAlice.OpenConn(ra, "bob", nil)
|
conn, err := clientAlice.OpenConn(ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -348,12 +348,77 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||||
|
|
||||||
log.Infof("reopent the connection")
|
log.Infof("reopent the connection")
|
||||||
_, err = clientAlice.OpenConn(ra, "bob", nil)
|
_, err = clientAlice.OpenConn(ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to open channel: %s", err)
|
t.Errorf("failed to open channel: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNotifierDoubleAdd(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
srvCfg1 := server.ListenerConfig{
|
||||||
|
Address: "localhost:1234",
|
||||||
|
}
|
||||||
|
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create server: %s", err)
|
||||||
|
}
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := srv1.Listen(srvCfg1)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv1.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := waitForServerToStart(errChan); err != nil {
|
||||||
|
t.Fatalf("failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idAlice := "alice"
|
||||||
|
log.Debugf("connect by alice")
|
||||||
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||||
|
err = clientAlice.Serve()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn1, err := clientAlice.OpenConn(clientAlice.ServerURL(), "idBob")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fnCloseListener := OnServerCloseListener(func() {
|
||||||
|
log.Infof("close listener")
|
||||||
|
})
|
||||||
|
|
||||||
|
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to add close listener: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to add close listener: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn1.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func toURL(address server.ListenerConfig) string {
|
func toURL(address server.ListenerConfig) string {
|
||||||
return "rel://" + address.Address
|
return "rel://" + address.Address
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user