Fix memory leak

Avoid to add listeners to multiple times
This commit is contained in:
Zoltán Papp 2024-07-25 17:21:27 +02:00
parent b62ad97e59
commit d70df99f7b
3 changed files with 124 additions and 34 deletions

View File

@ -34,18 +34,18 @@ type WorkerRelay struct {
relayManager relayClient.ManagerService relayManager relayClient.ManagerService
conn WorkerRelayCallbacks conn WorkerRelayCallbacks
ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
} }
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
return &WorkerRelay{ r := &WorkerRelay{
parentCtx: ctx, parentCtx: ctx,
log: log, log: log,
config: config, config: config,
relayManager: relayManager, relayManager: relayManager,
conn: callbacks, conn: callbacks,
} }
return r
} }
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
@ -63,7 +63,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.disconnected) relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
if err != nil { if err != nil {
// todo handle all type errors // todo handle all type errors
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
@ -74,11 +74,20 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return return
} }
w.ctx, w.ctxCancel = context.WithCancel(w.parentCtx) ctx, ctxCancel := context.WithCancel(w.parentCtx)
w.ctxCancel = ctxCancel
go w.wgStateCheck(relayedConn) err = w.relayManager.AddCloseListener(srv, w.disconnected)
if err != nil {
log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close()
ctxCancel()
return
}
w.log.Debugf("Relay connection established with %s", srv) go w.wgStateCheck(ctx, relayedConn)
w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.conn.OnConnReady(RelayConnInfo{ go w.conn.OnConnReady(RelayConnInfo{
relayedConn: relayedConn, relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
@ -99,7 +108,7 @@ func (w *WorkerRelay) RelayIsSupportedLocally() bool {
} }
// wgStateCheck help to check the state of the wireguard handshake and relay connection // wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(conn net.Conn) { func (w *WorkerRelay) wgStateCheck(ctx context.Context, conn net.Conn) {
timer := time.NewTimer(wgHandshakeOvertime) timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop() defer timer.Stop()
for { for {
@ -120,7 +129,7 @@ func (w *WorkerRelay) wgStateCheck(conn net.Conn) {
} }
resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod)) resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod))
timer.Reset(resetTime) timer.Reset(resetTime)
case <-w.ctx.Done(): case <-ctx.Done():
return return
} }
} }
@ -149,6 +158,8 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
} }
func (w *WorkerRelay) disconnected() { func (w *WorkerRelay) disconnected() {
if w.ctxCancel != nil {
w.ctxCancel() w.ctxCancel()
}
w.conn.OnDisconnected() w.conn.OnDisconnected()
} }

View File

@ -1,9 +1,11 @@
package client package client
import ( import (
"container/list"
"context" "context"
"fmt" "fmt"
"net" "net"
"reflect"
"sync" "sync"
"time" "time"
@ -30,9 +32,12 @@ func NewRelayTrack() *RelayTrack {
return &RelayTrack{} return &RelayTrack{}
} }
type OnServerCloseListener func()
type ManagerService interface { type ManagerService interface {
Serve() error Serve() error
OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) OpenConn(serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error) RelayInstanceAddress() (string, error)
ServerURL() string ServerURL() string
HasRelayAddress() bool HasRelayAddress() bool
@ -57,7 +62,7 @@ type Manager struct {
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]map[*func()]struct{} onDisconnectedListeners map[string]*list.List
listenerLock sync.Mutex listenerLock sync.Mutex
} }
@ -68,7 +73,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
peerID: peerID, peerID: peerID,
tokenStore: &relayAuth.TokenStore{}, tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}), onDisconnectedListeners: make(map[string]*list.List),
} }
} }
@ -97,7 +102,7 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. // connection to the relay server.
func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) { func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil { if m.relayClient == nil {
return nil, errRelayClientNotConnected return nil, errRelayClientNotConnected
} }
@ -121,19 +126,23 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func(
return nil, err return nil, err
} }
if onClosedListener != nil { return netConn, err
}
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
}
var listenerAddr string var listenerAddr string
if foreign { if foreign {
m.addListener(serverAddress, onClosedListener)
listenerAddr = serverAddress listenerAddr = serverAddress
} else { } else {
listenerAddr = m.serverURL listenerAddr = m.serverURL
} }
m.addListener(listenerAddr, onClosedListener) m.addListener(listenerAddr, onClosedListener)
return nil
}
return netConn, err
} }
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost. // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
@ -265,14 +274,19 @@ func (m *Manager) cleanUpUnusedRelays() {
} }
} }
func (m *Manager) addListener(serverAddress string, onClosedListener func()) { func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
m.listenerLock.Lock() m.listenerLock.Lock()
defer m.listenerLock.Unlock() defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress] l, ok := m.onDisconnectedListeners[serverAddress]
if !ok { if !ok {
l = make(map[*func()]struct{}) l = list.New()
} }
l[&onClosedListener] = struct{}{} for e := l.Front(); e != nil; e = e.Next() {
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
return
}
}
l.PushBack(onClosedListener)
m.onDisconnectedListeners[serverAddress] = l m.onDisconnectedListeners[serverAddress] = l
} }
@ -284,8 +298,8 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
if !ok { if !ok {
return return
} }
for f := range l { for e := l.Front(); e != nil; e = e.Next() {
go (*f)() go e.Value.(OnServerCloseListener)()
} }
delete(m.onDisconnectedListeners, serverAddress) delete(m.onDisconnectedListeners, serverAddress)
} }

View File

@ -87,11 +87,11 @@ func TestForeignConn(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil) connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil) connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -187,7 +187,7 @@ func TestForeginConnClose(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil) conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -269,7 +269,7 @@ func TestForeginAutoClose(t *testing.T) {
} }
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil) conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -330,7 +330,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra, "bob", nil) conn, err := clientAlice.OpenConn(ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -348,12 +348,77 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob", nil) _, err = clientAlice.OpenConn(ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }
} }
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(clientAlice.ServerURL(), "idBob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
fnCloseListener := OnServerCloseListener(func() {
log.Infof("close listener")
})
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = conn1.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
func toURL(address server.ListenerConfig) string { func toURL(address server.ListenerConfig) string {
return "rel://" + address.Address return "rel://" + address.Address
} }