Fix parameters of tests

This commit is contained in:
Zoltán Papp 2024-07-08 17:01:11 +02:00
parent 75f5b75bc4
commit 1f949f8cee
7 changed files with 76 additions and 62 deletions

View File

@ -4,20 +4,20 @@ import (
"sync" "sync"
) )
// Store is a simple in-memory store for token // TokenStore is a simple in-memory store for token
// With this can update the token in thread safe way // With this can update the token in thread safe way
type Store struct { type TokenStore struct {
mu sync.Mutex mu sync.Mutex
token Token token Token
} }
func (a *Store) UpdateToken(token Token) { func (a *TokenStore) UpdateToken(token Token) {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
a.token = token a.token = token
} }
func (a *Store) Token() ([]byte, error) { func (a *TokenStore) Token() ([]byte, error) {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
return marshalToken(a.token) return marshalToken(a.token)

View File

@ -96,11 +96,11 @@ func (cc *connContainer) close() {
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too. // the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established. // While the Connect is in progress, the OpenConn function will block until the connection is established.
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context parentCtx context.Context
connectionURL string connectionURL string
authStore *auth.Store authTokenStore *auth.TokenStore
hashedID []byte hashedID []byte
bufPool *sync.Pool bufPool *sync.Pool
@ -117,14 +117,14 @@ type Client struct {
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ return &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithField("client_id", hashedStringId),
parentCtx: ctx, parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authStore: authStore, authTokenStore: authTokenStore,
hashedID: hashedID, hashedID: hashedID,
bufPool: &sync.Pool{ bufPool: &sync.Pool{
New: func() any { New: func() any {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
@ -237,7 +237,7 @@ func (c *Client) connect() error {
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
t, err := c.authStore.Token() t, err := c.authTokenStore.Token()
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,11 +9,20 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
var (
av = &auth.AllowAllAuth{}
hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "localhost:1234"
serverURL = "rel://localhost:1234"
)
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console") _ = util.InitLog("trace", "console")
code := m.Run() code := m.Run()
@ -23,8 +32,8 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -39,21 +48,21 @@ func TestClient(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder") clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
clientBob := NewClient(ctx, srvCfg.Address, "bob") clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -91,8 +100,8 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -100,7 +109,7 @@ func TestRegistration(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
_ = srv.Close() _ = srv.Close()
@ -140,7 +149,7 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close() _ = fakeTCPListener.Close()
}(fakeTCPListener) }(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice") clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -156,8 +165,8 @@ func TestEcho(t *testing.T) {
ctx := context.Background() ctx := context.Background()
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -172,7 +181,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, idAlice) clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -184,7 +193,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, srvCfg.Address, idBob) clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -236,8 +245,8 @@ func TestEcho(t *testing.T) {
func TestBindToUnavailabePeer(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -253,7 +262,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -273,8 +282,8 @@ func TestBindToUnavailabePeer(t *testing.T) {
func TestBindReconnect(t *testing.T) { func TestBindReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -290,7 +299,7 @@ func TestBindReconnect(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -301,7 +310,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
clientBob := NewClient(ctx, srvCfg.Address, "bob") clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -318,7 +327,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = NewClient(ctx, srvCfg.Address, "alice") clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -355,8 +364,8 @@ func TestBindReconnect(t *testing.T) {
func TestCloseConn(t *testing.T) { func TestCloseConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -372,7 +381,7 @@ func TestCloseConn(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -403,8 +412,8 @@ func TestCloseConn(t *testing.T) {
func TestCloseRelayConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -419,7 +428,7 @@ func TestCloseRelayConn(t *testing.T) {
} }
}() }()
clientAlice := NewClient(ctx, srvCfg.Address, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -446,8 +455,8 @@ func TestCloseRelayConn(t *testing.T) {
func TestCloseByServer(t *testing.T) { func TestCloseByServer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1 := server.NewServer(srvCfg.Address, false) srv1 := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv1.Listen(srvCfg) err := srv1.Listen(srvCfg)
if err != nil { if err != nil {
@ -457,7 +466,7 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, srvCfg.Address, idAlice) relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := relayClient.Connect() err := relayClient.Connect()
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
@ -489,8 +498,8 @@ func TestCloseByServer(t *testing.T) {
func TestCloseByClient(t *testing.T) { func TestCloseByClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: "localhost:1234"} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(serverURL, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {
@ -500,7 +509,7 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, srvCfg.Address, idAlice) relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := relayClient.Connect() err := relayClient.Connect()
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)

View File

@ -40,7 +40,7 @@ type Manager struct {
ctx context.Context ctx context.Context
serverURL string serverURL string
peerID string peerID string
tokenStore *relayAuth.Store tokenStore *relayAuth.TokenStore
relayClient *Client relayClient *Client
reconnectGuard *Guard reconnectGuard *Guard
@ -57,7 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
ctx: ctx, ctx: ctx,
serverURL: serverURL, serverURL: serverURL,
peerID: peerID, peerID: peerID,
tokenStore: &relayAuth.Store{}, tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}), onDisconnectedListeners: make(map[string]map[*func()]struct{}),
} }

View File

@ -16,7 +16,7 @@ func TestForeignConn(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1 := server.NewServer(srvCfg1.Address, false) srv1 := server.NewServer(srvCfg1.Address, false, av)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(srvCfg1)
if err != nil { if err != nil {
@ -34,7 +34,7 @@ func TestForeignConn(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2 := server.NewServer(srvCfg2.Address, false) srv2 := server.NewServer(srvCfg2.Address, false, av)
go func() { go func() {
err := srv2.Listen(srvCfg2) err := srv2.Listen(srvCfg2)
if err != nil { if err != nil {
@ -107,7 +107,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1 := server.NewServer(srvCfg1.Address, false) srv1 := server.NewServer(srvCfg1.Address, false, av)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(srvCfg1)
if err != nil { if err != nil {
@ -125,7 +125,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2 := server.NewServer(srvCfg2.Address, false) srv2 := server.NewServer(srvCfg2.Address, false, av)
go func() { go func() {
err := srv2.Listen(srvCfg2) err := srv2.Listen(srvCfg2)
if err != nil { if err != nil {
@ -164,7 +164,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1 := server.NewServer(srvCfg1.Address, false) srv1 := server.NewServer(srvCfg1.Address, false, av)
go func() { go func() {
t.Log("binding server 1.") t.Log("binding server 1.")
err := srv1.Listen(srvCfg1) err := srv1.Listen(srvCfg1)
@ -185,7 +185,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2 := server.NewServer(srvCfg2.Address, false) srv2 := server.NewServer(srvCfg2.Address, false, av)
go func() { go func() {
t.Log("binding server 2.") t.Log("binding server 2.")
err := srv2.Listen(srvCfg2) err := srv2.Listen(srvCfg2)
@ -237,7 +237,7 @@ func TestAutoReconnect(t *testing.T) {
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv := server.NewServer(srvCfg.Address, false) srv := server.NewServer(srvCfg.Address, false, av)
go func() { go func() {
err := srv.Listen(srvCfg) err := srv.Listen(srvCfg)
if err != nil { if err != nil {

View File

@ -6,11 +6,13 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -82,7 +84,9 @@ func execute(cmd *cobra.Command, args []string) {
} }
tlsSupport := srvListenerCfg.TLSConfig != nil tlsSupport := srvListenerCfg.TLSConfig != nil
srv := server.NewServer(exposedAddress, tlsSupport, authSecret)
authenticator := auth.NewTimedHMACValidator(authSecret, 24*time.Hour)
srv := server.NewServer(exposedAddress, tlsSupport, authenticator)
log.Infof("server will be available on: %s", srv.InstanceURL()) log.Infof("server will be available on: %s", srv.InstanceURL())
err := srv.Listen(srvListenerCfg) err := srv.Listen(srvListenerCfg)
if err != nil { if err != nil {

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp" "github.com/netbirdio/netbird/relay/server/listener/udp"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
@ -26,12 +26,13 @@ type Server struct {
wSListener listener.Listener wSListener listener.Listener
} }
func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server { func NewServer(exposedAddress string, tlsSupport bool, authValidator auth.Validator) *Server {
return &Server{ return &Server{
relay: NewRelay( relay: NewRelay(
exposedAddress, exposedAddress,
tlsSupport, tlsSupport,
auth.NewTimedHMACValidator(authSecret, 24*time.Hour)), authValidator,
),
} }
} }