Fix and extend test

This commit is contained in:
Zoltán Papp 2024-06-19 09:40:43 +02:00
parent e26e2c3a75
commit 24f71bc68a
2 changed files with 26 additions and 53 deletions

View File

@ -75,7 +75,6 @@ type Conn struct {
signaler *Signaler signaler *Signaler
allowedIPsIP string allowedIPsIP string
handshaker *Handshaker handshaker *Handshaker
closeCh chan struct{}
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string) onDisconnected func(remotePeer string, wgIP string)
@ -116,7 +115,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
handshaker: NewHandshaker(ctx, config, signaler), handshaker: NewHandshaker(ctx, config, signaler),
statusRelay: StatusDisconnected, statusRelay: StatusDisconnected,
statusICE: StatusDisconnected, statusICE: StatusDisconnected,
closeCh: make(chan struct{}),
} }
conn.workerICE = NewWorkerICE(ctx, conn.log, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, conn.iCEConnectionIsReady, conn.onWorkerICEStateChanged, conn.doHandshake) conn.workerICE = NewWorkerICE(ctx, conn.log, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, conn.iCEConnectionIsReady, conn.onWorkerICEStateChanged, conn.doHandshake)
conn.workerRelay = NewWorkerRelay(ctx, conn.log, relayManager, config, conn.relayConnectionIsReady, conn.onWorkerRelayStateChanged, conn.doHandshake) conn.workerRelay = NewWorkerRelay(ctx, conn.log, relayManager, config, conn.relayConnectionIsReady, conn.onWorkerRelayStateChanged, conn.doHandshake)

View File

@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/stun/v2"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
@ -15,12 +14,13 @@ import (
) )
var connConf = ConnConfig{ var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
StunTurn: []*stun.URI{}, Timeout: time.Second,
InterfaceBlackList: nil, LocalWgPort: 51820,
Timeout: time.Second, ICEConfig: ICEConfig{
LocalWgPort: 51820, InterfaceBlackList: nil,
},
} }
func TestNewConn_interfaceFilter(t *testing.T) { func TestNewConn_interfaceFilter(t *testing.T) {
@ -40,7 +40,7 @@ func TestConn_GetKey(t *testing.T) {
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil, nil) conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
if err != nil { if err != nil {
return return
} }
@ -55,7 +55,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil { if err != nil {
return return
} }
@ -63,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go func() { go func() {
<-conn.remoteOffersCh <-conn.handshaker.remoteOffersCh
wg.Done() wg.Done()
}() }()
@ -92,7 +92,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil { if err != nil {
return return
} }
@ -100,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go func() { go func() {
<-conn.remoteAnswerCh <-conn.handshaker.remoteAnswerCh
wg.Done() wg.Done()
}() }()
@ -128,58 +128,33 @@ func TestConn_Status(t *testing.T) {
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil { if err != nil {
return return
} }
tables := []struct { tables := []struct {
name string name string
status ConnStatus statusIce ConnStatus
want ConnStatus statusRelay ConnStatus
want ConnStatus
}{ }{
{"StatusConnected", StatusConnected, StatusConnected}, {"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected}, {"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting}, {"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
} }
for _, table := range tables { for _, table := range tables {
t.Run(table.name, func(t *testing.T) { t.Run(table.name, func(t *testing.T) {
conn.status = table.status conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
got := conn.Status() got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")
}) })
} }
} }
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
<-conn.closeCh
wg.Done()
}()
go func() {
for {
err := conn.Close()
if err != nil {
continue
} else {
return
}
}
}()
wg.Wait()
}