diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 0e341ed14..cf72e30e6 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -34,18 +34,18 @@ type WorkerRelay struct { relayManager relayClient.ManagerService conn WorkerRelayCallbacks - ctx context.Context ctxCancel context.CancelFunc } func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { - return &WorkerRelay{ + r := &WorkerRelay{ parentCtx: ctx, log: log, config: config, relayManager: relayManager, conn: callbacks, } + return r } func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { @@ -63,7 +63,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) - relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.disconnected) + relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) if err != nil { // todo handle all type errors if errors.Is(err, relayClient.ErrConnAlreadyExists) { @@ -74,11 +74,20 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { return } - w.ctx, w.ctxCancel = context.WithCancel(w.parentCtx) + ctx, ctxCancel := context.WithCancel(w.parentCtx) + w.ctxCancel = ctxCancel - go w.wgStateCheck(relayedConn) + err = w.relayManager.AddCloseListener(srv, w.disconnected) + if err != nil { + log.Errorf("failed to add close listener: %s", err) + _ = relayedConn.Close() + ctxCancel() + return + } - w.log.Debugf("Relay connection established with %s", srv) + go w.wgStateCheck(ctx, relayedConn) + + w.log.Debugf("peer conn opened via Relay: %s", srv) go w.conn.OnConnReady(RelayConnInfo{ relayedConn: relayedConn, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, @@ -99,7 +108,7 @@ func (w *WorkerRelay) RelayIsSupportedLocally() bool { } // wgStateCheck help to check the state of the wireguard handshake and relay connection -func (w *WorkerRelay) wgStateCheck(conn net.Conn) { +func (w *WorkerRelay) wgStateCheck(ctx context.Context, conn net.Conn) { timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop() for { @@ -120,7 +129,7 @@ func (w *WorkerRelay) wgStateCheck(conn net.Conn) { } resetTime := time.Until(lastHandshake.Add(wgHandshakeOvertime + wgHandshakePeriod)) timer.Reset(resetTime) - case <-w.ctx.Done(): + case <-ctx.Done(): return } } @@ -149,6 +158,8 @@ func (w *WorkerRelay) wgState() (time.Time, error) { } func (w *WorkerRelay) disconnected() { - w.ctxCancel() + if w.ctxCancel != nil { + w.ctxCancel() + } w.conn.OnDisconnected() } diff --git a/relay/client/manager.go b/relay/client/manager.go index e75455e47..bb523822a 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -1,9 +1,11 @@ package client import ( + "container/list" "context" "fmt" "net" + "reflect" "sync" "time" @@ -30,9 +32,12 @@ func NewRelayTrack() *RelayTrack { return &RelayTrack{} } +type OnServerCloseListener func() + type ManagerService interface { Serve() error - OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) + OpenConn(serverAddress, peerKey string) (net.Conn, error) + AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error RelayInstanceAddress() (string, error) ServerURL() string HasRelayAddress() bool @@ -57,7 +62,7 @@ type Manager struct { relayClients map[string]*RelayTrack relayClientsMutex sync.RWMutex - onDisconnectedListeners map[string]map[*func()]struct{} + onDisconnectedListeners map[string]*list.List listenerLock sync.Mutex } @@ -68,7 +73,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager { peerID: peerID, tokenStore: &relayAuth.TokenStore{}, relayClients: make(map[string]*RelayTrack), - onDisconnectedListeners: make(map[string]map[*func()]struct{}), + onDisconnectedListeners: make(map[string]*list.List), } } @@ -97,7 +102,7 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. -func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) { +func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { if m.relayClient == nil { return nil, errRelayClientNotConnected } @@ -121,19 +126,23 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func( return nil, err } - if onClosedListener != nil { - var listenerAddr string - if foreign { - m.addListener(serverAddress, onClosedListener) - listenerAddr = serverAddress - } else { - listenerAddr = m.serverURL - } - m.addListener(listenerAddr, onClosedListener) + return netConn, err +} +func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + foreign, err := m.isForeignServer(serverAddress) + if err != nil { + return err } - return netConn, err + var listenerAddr string + if foreign { + listenerAddr = serverAddress + } else { + listenerAddr = m.serverURL + } + m.addListener(listenerAddr, onClosedListener) + return nil } // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost. @@ -265,14 +274,19 @@ func (m *Manager) cleanUpUnusedRelays() { } } -func (m *Manager) addListener(serverAddress string, onClosedListener func()) { +func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) { m.listenerLock.Lock() defer m.listenerLock.Unlock() l, ok := m.onDisconnectedListeners[serverAddress] if !ok { - l = make(map[*func()]struct{}) + l = list.New() } - l[&onClosedListener] = struct{}{} + for e := l.Front(); e != nil; e = e.Next() { + if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() { + return + } + } + l.PushBack(onClosedListener) m.onDisconnectedListeners[serverAddress] = l } @@ -284,8 +298,8 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) { if !ok { return } - for f := range l { - go (*f)() + for e := l.Front(); e != nil; e = e.Next() { + go e.Value.(OnServerCloseListener)() } delete(m.onDisconnectedListeners, serverAddress) } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 928171175..cfec3f54e 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -87,11 +87,11 @@ func TestForeignConn(t *testing.T) { if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil) + connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil) + connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -187,7 +187,7 @@ func TestForeginConnClose(t *testing.T) { if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil) + conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -269,7 +269,7 @@ func TestForeginAutoClose(t *testing.T) { } t.Log("open connection to another peer") - conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil) + conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -330,7 +330,7 @@ func TestAutoReconnect(t *testing.T) { if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ra, "bob", nil) + conn, err := clientAlice.OpenConn(ra, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -348,12 +348,77 @@ func TestAutoReconnect(t *testing.T) { time.Sleep(reconnectingTimeout + 1*time.Second) log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ra, "bob", nil) + _, err = clientAlice.OpenConn(ra, "bob") if err != nil { t.Errorf("failed to open channel: %s", err) } } +func TestNotifierDoubleAdd(t *testing.T) { + ctx := context.Background() + + srvCfg1 := server.ListenerConfig{ + Address: "localhost:1234", + } + srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv1.Listen(srvCfg1) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv1.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) + err = clientAlice.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + conn1, err := clientAlice.OpenConn(clientAlice.ServerURL(), "idBob") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + fnCloseListener := OnServerCloseListener(func() { + log.Infof("close listener") + }) + + err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener) + if err != nil { + t.Fatalf("failed to add close listener: %s", err) + } + + err = clientAlice.AddCloseListener(clientAlice.ServerURL(), fnCloseListener) + if err != nil { + t.Fatalf("failed to add close listener: %s", err) + } + + err = conn1.Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + +} + func toURL(address server.ListenerConfig) string { return "rel://" + address.Address }