diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 092356b3d..bc965de13 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -97,7 +97,8 @@ type Conn struct { workerICE *WorkerICE workerRelay *WorkerRelay - connID nbnet.ConnectionID + connIDRelay nbnet.ConnectionID + connIDICE nbnet.ConnectionID beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc @@ -234,14 +235,7 @@ func (conn *Conn) Close() { conn.log.Errorf("failed to remove wg endpoint: %v", err) } - if conn.connID != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connID); err != nil { - conn.log.Errorf("After remove peer hook failed: %v", err) - } - } - conn.connID = "" - } + conn.freeUpConnID() if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) @@ -443,9 +437,9 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - conn.connID = nbnet.GenerateConnID() + conn.connIDICE = nbnet.GenerateConnID() for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { conn.log.Errorf("Before add peer hook failed: %v", err) } } @@ -541,9 +535,9 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { } } - conn.connID = nbnet.GenerateConnID() + conn.connIDRelay = nbnet.GenerateConnID() for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { conn.log.Errorf("Before add peer hook failed: %v", err) } } @@ -728,6 +722,26 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) freeUpConnID() { + if conn.connIDRelay != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connIDRelay); err != nil { + conn.log.Errorf("After remove peer hook failed: %v", err) + } + } + conn.connIDRelay = "" + } + + if conn.connIDICE != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connIDICE); err != nil { + conn.log.Errorf("After remove peer hook failed: %v", err) + } + } + conn.connIDICE = "" + } +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil }