Fix memory leak

Avoid to add listeners to multiple times
This commit is contained in:
Zoltán Papp 2024-07-25 17:21:27 +02:00
parent b62ad97e59
commit d70df99f7b
3 changed files with 124 additions and 34 deletions

View File

@ -34,18 +34,18 @@ type WorkerRelay struct {
relayManager relayClient.ManagerService
conn WorkerRelayCallbacks
ctx context.Context
ctxCancel context.CancelFunc
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
return &WorkerRelay{
r := &WorkerRelay{
parentCtx: ctx,
log: log,
config: config,
relayManager: relayManager,
conn: callbacks,
}
return r
}
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
@ -63,7 +63,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.disconnected)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
if err != nil {
// todo handle all type errors
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
@ -74,11 +74,20 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return
}
w.ctx, w.ctxCancel = context.WithCancel(w.parentCtx)
ctx, ctxCancel := context.WithCancel(w.parentCtx)
w.ctxCancel = ctxCancel
go w.wgStateCheck(relayedConn)
err = w.relayManager.AddCloseListener(srv, w.disconnected)
if err != nil {
log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close()
ctxCancel()
return
}
w.log.Debugf("Relay connection established with %s", srv)
go w.wgStateCheck(ctx, relayedConn)
w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.conn.OnConnReady(RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
@ -99,7 +108,7 @@ func (w *WorkerRelay) RelayIsSupportedLocally() bool {
}
// wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
func (w *WorkerRelay) wgStateCheck(ctx context.Context, conn net.Conn) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
for {
@ -120,7 +129,7 @@ func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
}
resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod))
timer.Reset(resetTime)
case <-w.ctx.Done():
case <-ctx.Done():
return
}
}
@ -149,6 +158,8 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
}
func (w *WorkerRelay) disconnected() {
w.ctxCancel()
if w.ctxCancel != nil {
w.ctxCancel()
}
w.conn.OnDisconnected()
}

View File

@ -1,9 +1,11 @@
package client
import (
"container/list"
"context"
"fmt"
"net"
"reflect"
"sync"
"time"
@ -30,9 +32,12 @@ func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
type OnServerCloseListener func()
type ManagerService interface {
Serve() error
OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error)
OpenConn(serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURL() string
HasRelayAddress() bool
@ -57,7 +62,7 @@ type Manager struct {
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]map[*func()]struct{}
onDisconnectedListeners map[string]*list.List
listenerLock sync.Mutex
}
@ -68,7 +73,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
peerID: peerID,
tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
onDisconnectedListeners: make(map[string]*list.List),
}
}
@ -97,7 +102,7 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server.
func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) {
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil {
return nil, errRelayClientNotConnected
}
@ -121,19 +126,23 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func(
return nil, err
}
if onClosedListener != nil {
var listenerAddr string
if foreign {
m.addListener(serverAddress, onClosedListener)
listenerAddr = serverAddress
} else {
listenerAddr = m.serverURL
}
m.addListener(listenerAddr, onClosedListener)
return netConn, err
}
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
}
return netConn, err
var listenerAddr string
if foreign {
listenerAddr = serverAddress
} else {
listenerAddr = m.serverURL
}
m.addListener(listenerAddr, onClosedListener)
return nil
}
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
@ -265,14 +274,19 @@ func (m *Manager) cleanUpUnusedRelays() {
}
}
func (m *Manager) addListener(serverAddress string, onClosedListener func()) {
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = make(map[*func()]struct{})
l = list.New()
}
l[&onClosedListener] = struct{}{}
for e := l.Front(); e != nil; e = e.Next() {
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
return
}
}
l.PushBack(onClosedListener)
m.onDisconnectedListeners[serverAddress] = l
}
@ -284,8 +298,8 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
if !ok {
return
}
for f := range l {
go (*f)()
for e := l.Front(); e != nil; e = e.Next() {
go e.Value.(OnServerCloseListener)()
}
delete(m.onDisconnectedListeners, serverAddress)
}

View File

@ -87,11 +87,11 @@ func TestForeignConn(t *testing.T) {
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil)
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil)
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@ -187,7 +187,7 @@ func TestForeginConnClose(t *testing.T) {
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@ -269,7 +269,7 @@ func TestForeginAutoClose(t *testing.T) {
}
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@ -330,7 +330,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra, "bob", nil)
conn, err := clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
@ -348,12 +348,77 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob", nil)
_, err = clientAlice.OpenConn(ra, "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(clientAlice.ServerURL(), "idBob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
fnCloseListener := OnServerCloseListener(func() {
log.Infof("close listener")
})
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = conn1.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
func toURL(address server.ListenerConfig) string {
return "rel://" + address.Address
}