package client import ( "context" "net" "os" "testing" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/relay/server" ) var ( av = &auth.AllowAllAuth{} hmacTokenStore = &hmac.TokenStore{} serverListenAddr = "localhost:1234" serverURL = "rel://localhost:1234" ) func TestMain(m *testing.M) { _ = util.InitLog("trace", "console") code := m.Run() os.Exit(code) } func TestClient(t *testing.T) { ctx := context.Background() srv := server.NewServer(serverURL, false, av) go func() { listenCfg := server.ListenerConfig{Address: serverListenAddr} err := srv.Listen(listenCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() defer func() { err := srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } }() t.Log("alice connecting to server") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() t.Log("placeholder connecting to server") clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") err = clientPlaceHolder.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() t.Log("Bob connecting to server") clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientBob.Close() t.Log("Alice open connection to Bob") connAliceToBob, err := clientAlice.OpenConn("bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } t.Log("Bob open connection to Alice") connBobToAlice, err := clientBob.OpenConn("alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } payload := "hello bob, I am alice" _, err = connAliceToBob.Write([]byte(payload)) if err != nil { t.Fatalf("failed to write to channel: %s", err) } log.Debugf("alice sent message to bob") buf := make([]byte, 65535) n, err := connBobToAlice.Read(buf) if err != nil { t.Fatalf("failed to read from channel: %s", err) } log.Debugf("on new message from alice to bob") if payload != string(buf[:n]) { t.Fatalf("expected %s, got %s", payload, string(buf[:n])) } } func TestRegistration(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { _ = srv.Close() t.Fatalf("failed to connect to server: %s", err) } err = clientAlice.Close() if err != nil { t.Errorf("failed to close conn: %s", err) } err = srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } } func TestRegistrationTimeout(t *testing.T) { ctx := context.Background() fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{ Port: 1234, IP: net.ParseIP("0.0.0.0"), }) if err != nil { t.Fatalf("failed to bind UDP server: %s", err) } defer func(fakeUDPListener *net.UDPConn) { _ = fakeUDPListener.Close() }(fakeUDPListener) fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{ Port: 1234, IP: net.ParseIP("0.0.0.0"), }) if err != nil { t.Fatalf("failed to bind TCP server: %s", err) } defer func(fakeTCPListener *net.TCPListener) { _ = fakeTCPListener.Close() }(fakeTCPListener) clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") err = clientAlice.Connect() if err == nil { t.Errorf("failed to connect to server: %s", err) } log.Debugf("%s", err) err = clientAlice.Close() if err != nil { t.Errorf("failed to close conn: %s", err) } } func TestEcho(t *testing.T) { ctx := context.Background() idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() defer func() { err := srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer func() { err := clientAlice.Close() if err != nil { t.Errorf("failed to close Alice client: %s", err) } }() clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer func() { err := clientBob.Close() if err != nil { t.Errorf("failed to close Bob client: %s", err) } }() connAliceToBob, err := clientAlice.OpenConn(idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } connBobToAlice, err := clientBob.OpenConn(idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } payload := "hello bob, I am alice" _, err = connAliceToBob.Write([]byte(payload)) if err != nil { t.Fatalf("failed to write to channel: %s", err) } buf := make([]byte, 65535) n, err := connBobToAlice.Read(buf) if err != nil { t.Fatalf("failed to read from channel: %s", err) } _, err = connBobToAlice.Write(buf[:n]) if err != nil { t.Fatalf("failed to write to channel: %s", err) } n, err = connAliceToBob.Read(buf) if err != nil { t.Fatalf("failed to read from channel: %s", err) } if payload != string(buf[:n]) { t.Fatalf("expected %s, got %s", payload, string(buf[:n])) } } func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() defer func() { log.Infof("closing server") err := srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) } _, err = clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } log.Infof("closing client") err = clientAlice.Close() if err != nil { t.Errorf("failed to close client: %s", err) } } func TestBindReconnect(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } }() defer func() { log.Infof("closing server") err := srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) } _, err = clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") err = clientBob.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) } chBob, err := clientBob.OpenConn("alice") if err != nil { t.Errorf("failed to bind channel: %s", err) } log.Infof("closing client Alice") err = clientAlice.Close() if err != nil { t.Errorf("failed to close client: %s", err) } clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") err = clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) } chAlice, err := clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } testString := "hello alice, I am bob" _, err = chBob.Write([]byte(testString)) if err != nil { t.Errorf("failed to write to channel: %s", err) } buf := make([]byte, 65535) n, err := chAlice.Read(buf) if err != nil { t.Errorf("failed to read from channel: %s", err) } if testString != string(buf[:n]) { t.Errorf("expected %s, got %s", testString, string(buf[:n])) } log.Infof("closing client") err = clientAlice.Close() if err != nil { t.Errorf("failed to close client: %s", err) } } func TestCloseConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } }() defer func() { log.Infof("closing server") err := srv.Close() if err != nil { t.Errorf("failed to close server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) } conn, err := clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } log.Infof("closing connection") err = conn.Close() if err != nil { t.Errorf("failed to close connection: %s", err) } _, err = conn.Read(make([]byte, 1)) if err == nil { t.Errorf("unexpected reading from closed connection") } _, err = conn.Write([]byte("hello")) if err == nil { t.Errorf("unexpected writing from closed connection") } } func TestCloseRelayConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Errorf("failed to bind server: %s", err) } }() defer func() { err := srv.Close() if err != nil { log.Errorf("failed to close server: %s", err) } }() clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } conn, err := clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } _ = clientAlice.relayConn.Close() _, err = conn.Read(make([]byte, 1)) if err == nil { t.Errorf("unexpected reading from closed connection") } _, err = clientAlice.OpenConn("bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } } func TestCloseByServer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv1 := server.NewServer(serverURL, false, av) go func() { err := srv1.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() idAlice := "alice" log.Debugf("connect by alice") relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) } disconnected := make(chan struct{}) relayClient.SetOnDisconnectListener(func() { log.Infof("client disconnected") close(disconnected) }) err = srv1.Close() if err != nil { t.Fatalf("failed to close server: %s", err) } select { case <-disconnected: case <-time.After(3 * time.Second): log.Fatalf("timeout waiting for client to disconnect") } _, err = relayClient.OpenConn("bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } } func TestCloseByClient(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} srv := server.NewServer(serverURL, false, av) go func() { err := srv.Listen(srvCfg) if err != nil { t.Fatalf("failed to bind server: %s", err) } }() idAlice := "alice" log.Debugf("connect by alice") relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) err := relayClient.Connect() if err != nil { log.Fatalf("failed to connect to server: %s", err) } err = relayClient.Close() if err != nil { t.Errorf("failed to close client: %s", err) } _, err = relayClient.OpenConn("bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } err = srv.Close() if err != nil { t.Fatalf("failed to close server: %s", err) } }