diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 268071593..389b9eb64 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -75,7 +75,6 @@ type Conn struct { signaler *Signaler allowedIPsIP string handshaker *Handshaker - closeCh chan struct{} onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string, wgIP string) @@ -116,7 +115,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu handshaker: NewHandshaker(ctx, config, signaler), statusRelay: StatusDisconnected, statusICE: StatusDisconnected, - closeCh: make(chan struct{}), } conn.workerICE = NewWorkerICE(ctx, conn.log, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, conn.iCEConnectionIsReady, conn.onWorkerICEStateChanged, conn.doHandshake) conn.workerRelay = NewWorkerRelay(ctx, conn.log, relayManager, config, conn.relayConnectionIsReady, conn.onWorkerRelayStateChanged, conn.doHandshake) diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 10f1ac5ef..7dbb06b5a 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/magiconair/properties/assert" - "github.com/pion/stun/v2" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" @@ -15,12 +14,13 @@ import ( ) var connConf = ConnConfig{ - Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", - LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", - StunTurn: []*stun.URI{}, - InterfaceBlackList: nil, - Timeout: time.Second, - LocalWgPort: 51820, + Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + Timeout: time.Second, + LocalWgPort: 51820, + ICEConfig: ICEConfig{ + InterfaceBlackList: nil, + }, } func TestNewConn_interfaceFilter(t *testing.T) { @@ -40,7 +40,7 @@ func TestConn_GetKey(t *testing.T) { defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -55,7 +55,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -63,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) go func() { - <-conn.remoteOffersCh + <-conn.handshaker.remoteOffersCh wg.Done() }() @@ -92,7 +92,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -100,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) go func() { - <-conn.remoteAnswerCh + <-conn.handshaker.remoteAnswerCh wg.Done() }() @@ -128,58 +128,33 @@ func TestConn_Status(t *testing.T) { defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } tables := []struct { - name string - status ConnStatus - want ConnStatus + name string + statusIce ConnStatus + statusRelay ConnStatus + want ConnStatus }{ - {"StatusConnected", StatusConnected, StatusConnected}, - {"StatusDisconnected", StatusDisconnected, StatusDisconnected}, - {"StatusConnecting", StatusConnecting, StatusConnecting}, + {"StatusConnected", StatusConnected, StatusConnected, StatusConnected}, + {"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected}, + {"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting}, + {"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting}, + {"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected}, + {"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting}, + {"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected}, } for _, table := range tables { t.Run(table.name, func(t *testing.T) { - conn.status = table.status + conn.statusICE = table.statusIce + conn.statusRelay = table.statusRelay got := conn.Status() assert.Equal(t, got, table.want, "they should be equal") }) } } - -func TestConn_Close(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) - if err != nil { - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - <-conn.closeCh - wg.Done() - }() - - go func() { - for { - err := conn.Close() - if err != nil { - continue - } else { - return - } - } - }() - - wg.Wait() -}