// Mirror of https://github.com/netbirdio/netbird.git, synced 2024-12-15 11:21:04 +01:00.
// Commit 81425872e1: remove unused variable; rename struct.
// File: 253 lines, 5.5 KiB, Go.
package dialer
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// MockAddr is a stub net.Addr whose Network method reports a configurable
// string. The tests below use that string as a marker to tell which mock
// dialer produced a returned connection.
type MockAddr struct {
	network string
}

// Compile-time check that *MockAddr satisfies net.Addr.
var _ net.Addr = (*MockAddr)(nil)

// Network returns the configured network name.
func (m *MockAddr) Network() string {
	return m.network
}

// String always returns the fixed placeholder address "1.2.3.4".
func (m *MockAddr) String() string {
	return "1.2.3.4"
}
|
|
|
|
// MockDialer is a mock implementation of DialeFn. Dial forwards to the
// injected dialFunc, so each test can script success, failure or blocking
// behaviour; Protocol reports the fixed protocolStr label.
type MockDialer struct {
	// dialFunc is invoked by Dial with the caller's context and address.
	dialFunc func(ctx context.Context, address string) (net.Conn, error)
	// protocolStr is the label returned by Protocol.
	protocolStr string
}

// Dial delegates to the configured dialFunc and returns its results as-is.
func (md *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
	dial := md.dialFunc
	return dial(ctx, address)
}

// Protocol returns the configured protocol label.
func (md *MockDialer) Protocol() string {
	label := md.protocolStr
	return label
}
|
|
|
|
// MockConn implements net.Conn for testing. Every I/O and deadline method
// succeeds without doing anything; only RemoteAddr returns configured data.
//
// Read returns (0, nil) for any buffer. io.Reader discourages that for
// non-empty buffers, because a caller looping on Read would never progress.
// The tests in this file never read from the connection.
type MockConn struct {
	remoteAddr net.Addr
}

// Compile-time check that *MockConn satisfies net.Conn.
var _ net.Conn = (*MockConn)(nil)

// Read reports zero bytes read and no error.
func (m *MockConn) Read(b []byte) (n int, err error) {
	return 0, nil
}

// Write reports zero bytes written and no error; the data is discarded.
func (m *MockConn) Write(b []byte) (n int, err error) {
	return 0, nil
}

// Close is a no-op that always succeeds.
func (m *MockConn) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (m *MockConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns the address the mock was constructed with.
func (m *MockConn) RemoteAddr() net.Addr {
	return m.remoteAddr
}

// SetDeadline ignores t and always succeeds.
func (m *MockConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores t and always succeeds.
func (m *MockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores t and always succeeds.
func (m *MockConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|
|
|
|
func TestRaceDialEmptyDialers(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
|
|
rd := NewRaceDial(logger, serverURL)
|
|
conn, err := rd.Dial()
|
|
if err == nil {
|
|
t.Errorf("Expected an error with empty dialers, got nil")
|
|
}
|
|
if conn != nil {
|
|
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
|
|
}
|
|
}
|
|
|
|
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
proto := "test-protocol"
|
|
|
|
mockConn := &MockConn{
|
|
remoteAddr: &MockAddr{network: proto},
|
|
}
|
|
|
|
mockDialer := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
return mockConn, nil
|
|
},
|
|
protocolStr: proto,
|
|
}
|
|
|
|
rd := NewRaceDial(logger, serverURL, mockDialer)
|
|
conn, err := rd.Dial()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if conn == nil {
|
|
t.Errorf("Expected non-nil connection")
|
|
}
|
|
}
|
|
|
|
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
proto2 := "protocol2"
|
|
|
|
mockConn2 := &MockConn{
|
|
remoteAddr: &MockAddr{network: proto2},
|
|
}
|
|
|
|
mockDialer1 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
return nil, errors.New("first dialer failed")
|
|
},
|
|
protocolStr: "proto1",
|
|
}
|
|
|
|
mockDialer2 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
return mockConn2, nil
|
|
},
|
|
protocolStr: "proto2",
|
|
}
|
|
|
|
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
|
conn, err := rd.Dial()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if conn.RemoteAddr().Network() != proto2 {
|
|
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
|
}
|
|
}
|
|
|
|
func TestRaceDialTimeout(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
|
|
connectionTimeout = 3 * time.Second
|
|
mockDialer := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
<-ctx.Done()
|
|
return nil, ctx.Err()
|
|
},
|
|
protocolStr: "proto1",
|
|
}
|
|
|
|
rd := NewRaceDial(logger, serverURL, mockDialer)
|
|
conn, err := rd.Dial()
|
|
if err == nil {
|
|
t.Errorf("Expected an error, got nil")
|
|
}
|
|
if conn != nil {
|
|
t.Errorf("Expected nil connection, got %v", conn)
|
|
}
|
|
}
|
|
|
|
func TestRaceDialAllDialersFail(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
|
|
mockDialer1 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
return nil, errors.New("first dialer failed")
|
|
},
|
|
protocolStr: "protocol1",
|
|
}
|
|
|
|
mockDialer2 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
return nil, errors.New("second dialer failed")
|
|
},
|
|
protocolStr: "protocol2",
|
|
}
|
|
|
|
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
|
conn, err := rd.Dial()
|
|
if err == nil {
|
|
t.Errorf("Expected an error, got nil")
|
|
}
|
|
if conn != nil {
|
|
t.Errorf("Expected nil connection, got %v", conn)
|
|
}
|
|
}
|
|
|
|
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
|
logger := logrus.NewEntry(logrus.New())
|
|
serverURL := "test.server.com"
|
|
proto1 := "protocol1"
|
|
proto2 := "protocol2"
|
|
|
|
mockConn1 := &MockConn{
|
|
remoteAddr: &MockAddr{network: proto1},
|
|
}
|
|
|
|
mockConn2 := &MockConn{
|
|
remoteAddr: &MockAddr{network: proto2},
|
|
}
|
|
|
|
mockDialer1 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
time.Sleep(1 * time.Second)
|
|
return mockConn1, nil
|
|
},
|
|
protocolStr: proto1,
|
|
}
|
|
|
|
mock2err := make(chan error)
|
|
mockDialer2 := &MockDialer{
|
|
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
|
<-ctx.Done()
|
|
mock2err <- ctx.Err()
|
|
return mockConn2, ctx.Err()
|
|
},
|
|
protocolStr: proto2,
|
|
}
|
|
|
|
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
|
conn, err := rd.Dial()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if conn == nil {
|
|
t.Errorf("Expected non-nil connection")
|
|
}
|
|
if conn != mockConn1 {
|
|
t.Errorf("Expected first connection, got %v", conn)
|
|
}
|
|
|
|
select {
|
|
case <-time.After(3 * time.Second):
|
|
t.Errorf("Timed out waiting for second dialer to finish")
|
|
case err := <-mock2err:
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Errorf("Expected context.Canceled error, got %v", err)
|
|
}
|
|
}
|
|
}
|