mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
Fix memory leak
Avoid to add listeners to multiple times
This commit is contained in:
parent
b62ad97e59
commit
d70df99f7b
@ -34,18 +34,18 @@ type WorkerRelay struct {
|
||||
relayManager relayClient.ManagerService
|
||||
conn WorkerRelayCallbacks
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
||||
return &WorkerRelay{
|
||||
r := &WorkerRelay{
|
||||
parentCtx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
relayManager: relayManager,
|
||||
conn: callbacks,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
@ -63,7 +63,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.disconnected)
|
||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
||||
if err != nil {
|
||||
// todo handle all type errors
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
@ -74,11 +74,20 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
|
||||
w.ctx, w.ctxCancel = context.WithCancel(w.parentCtx)
|
||||
ctx, ctxCancel := context.WithCancel(w.parentCtx)
|
||||
w.ctxCancel = ctxCancel
|
||||
|
||||
go w.wgStateCheck(relayedConn)
|
||||
err = w.relayManager.AddCloseListener(srv, w.disconnected)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add close listener: %s", err)
|
||||
_ = relayedConn.Close()
|
||||
ctxCancel()
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("Relay connection established with %s", srv)
|
||||
go w.wgStateCheck(ctx, relayedConn)
|
||||
|
||||
w.log.Debugf("peer conn opened via Relay: %s", srv)
|
||||
go w.conn.OnConnReady(RelayConnInfo{
|
||||
relayedConn: relayedConn,
|
||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||
@ -99,7 +108,7 @@ func (w *WorkerRelay) RelayIsSupportedLocally() bool {
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the wireguard handshake and relay connection
|
||||
func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
|
||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context, conn net.Conn) {
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
@ -120,7 +129,7 @@ func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
|
||||
}
|
||||
resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod))
|
||||
timer.Reset(resetTime)
|
||||
case <-w.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -149,6 +158,8 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) disconnected() {
|
||||
w.ctxCancel()
|
||||
if w.ctxCancel != nil {
|
||||
w.ctxCancel()
|
||||
}
|
||||
w.conn.OnDisconnected()
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -30,9 +32,12 @@ func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
type ManagerService interface {
|
||||
Serve() error
|
||||
OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error)
|
||||
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||
RelayInstanceAddress() (string, error)
|
||||
ServerURL() string
|
||||
HasRelayAddress() bool
|
||||
@ -57,7 +62,7 @@ type Manager struct {
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
|
||||
onDisconnectedListeners map[string]map[*func()]struct{}
|
||||
onDisconnectedListeners map[string]*list.List
|
||||
listenerLock sync.Mutex
|
||||
}
|
||||
|
||||
@ -68,7 +73,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.TokenStore{},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,7 +102,7 @@ func (m *Manager) Serve() error {
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) {
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
if m.relayClient == nil {
|
||||
return nil, errRelayClientNotConnected
|
||||
}
|
||||
@ -121,19 +126,23 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if onClosedListener != nil {
|
||||
var listenerAddr string
|
||||
if foreign {
|
||||
m.addListener(serverAddress, onClosedListener)
|
||||
listenerAddr = serverAddress
|
||||
} else {
|
||||
listenerAddr = m.serverURL
|
||||
}
|
||||
m.addListener(listenerAddr, onClosedListener)
|
||||
return netConn, err
|
||||
}
|
||||
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return netConn, err
|
||||
var listenerAddr string
|
||||
if foreign {
|
||||
listenerAddr = serverAddress
|
||||
} else {
|
||||
listenerAddr = m.serverURL
|
||||
}
|
||||
m.addListener(listenerAddr, onClosedListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
|
||||
@ -265,14 +274,19 @@ func (m *Manager) cleanUpUnusedRelays() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener func()) {
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
l = make(map[*func()]struct{})
|
||||
l = list.New()
|
||||
}
|
||||
l[&onClosedListener] = struct{}{}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
|
||||
return
|
||||
}
|
||||
}
|
||||
l.PushBack(onClosedListener)
|
||||
m.onDisconnectedListeners[serverAddress] = l
|
||||
}
|
||||
|
||||
@ -284,8 +298,8 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for f := range l {
|
||||
go (*f)()
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
go e.Value.(OnServerCloseListener)()
|
||||
}
|
||||
delete(m.onDisconnectedListeners, serverAddress)
|
||||
}
|
||||
|
@ -87,11 +87,11 @@ func TestForeignConn(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil)
|
||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil)
|
||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -187,7 +187,7 @@ func TestForeginConnClose(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -269,7 +269,7 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
|
||||
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -330,7 +330,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ra, "bob", nil)
|
||||
conn, err := clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -348,12 +348,77 @@ func TestAutoReconnect(t *testing.T) {
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ra, "bob", nil)
|
||||
_, err = clientAlice.OpenConn(ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifierDoubleAdd(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(clientAlice.ServerURL(), "idBob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
fnCloseListener := OnServerCloseListener(func() {
|
||||
log.Infof("close listener")
|
||||
})
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = conn1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func toURL(address server.ListenerConfig) string {
|
||||
return "rel://" + address.Address
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user