mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Fix in client the close event
This commit is contained in:
parent
36b2cd16cc
commit
173ca25dac
@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -29,20 +30,25 @@ type connContainer struct {
|
|||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
serverAddress string
|
serverAddress string
|
||||||
hashedID []byte
|
hashedID []byte
|
||||||
|
|
||||||
conns map[string]*connContainer
|
conns map[string]*connContainer // todo handle it in thread safe way
|
||||||
|
|
||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
relayConnState bool
|
relayConnState bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(serverAddress, peerID string) *Client {
|
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
||||||
|
ctx, ctxCancel := context.WithCancel(ctx)
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID, hashedStringId := messages.HashID(peerID)
|
||||||
return &Client{
|
return &Client{
|
||||||
log: log.WithField("client_id", hashedStringId),
|
log: log.WithField("client_id", hashedStringId),
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: ctxCancel,
|
||||||
serverAddress: serverAddress,
|
serverAddress: serverAddress,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[string]*connContainer),
|
||||||
@ -51,7 +57,11 @@ func NewClient(serverAddress, peerID string) *Client {
|
|||||||
|
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
if c.relayConnState {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := udp.Dial(c.serverAddress)
|
conn, err := udp.Dial(c.serverAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -68,18 +78,39 @@ func (c *Client) Connect() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.relayConn.SetReadDeadline(time.Time{})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset read deadline: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.relayConnState = true
|
c.relayConnState = true
|
||||||
go c.readLoop()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-c.ctx.Done()
|
||||||
|
cErr := c.close()
|
||||||
|
if cErr != nil {
|
||||||
|
log.Errorf("failed to close relay connection: %s", cErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// blocking function
|
||||||
|
c.readLoop()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
|
||||||
|
// close all Conn types
|
||||||
|
for _, container := range c.conns {
|
||||||
|
close(container.messages)
|
||||||
|
}
|
||||||
|
c.conns = make(map[string]*connContainer)
|
||||||
|
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if !c.relayConnState {
|
||||||
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
|
}
|
||||||
|
|
||||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||||
log.Infof("open connection to peer: %s", hashedStringID)
|
log.Infof("open connection to peer: %s", hashedStringID)
|
||||||
messageBuffer := make(chan Msg, 2)
|
messageBuffer := make(chan Msg, 2)
|
||||||
@ -93,6 +124,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
|
c.ctxCancel()
|
||||||
|
return c.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@ -101,11 +137,20 @@ func (c *Client) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.relayConnState = false
|
c.relayConnState = false
|
||||||
|
|
||||||
err := c.relayConn.Close()
|
err := c.relayConn.Close()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake() error {
|
||||||
|
defer func() {
|
||||||
|
err := c.relayConn.SetReadDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to reset read deadline: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
msg, err := messages.MarshalHelloMsg(c.hashedID)
|
msg, err := messages.MarshalHelloMsg(c.hashedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to marshal hello message: %s", err)
|
log.Errorf("failed to marshal hello message: %s", err)
|
||||||
@ -145,7 +190,7 @@ func (c *Client) handShake() error {
|
|||||||
|
|
||||||
func (c *Client) readLoop() {
|
func (c *Client) readLoop() {
|
||||||
defer func() {
|
defer func() {
|
||||||
c.log.Debugf("exit from read loop")
|
c.log.Tracef("exit from read loop")
|
||||||
}()
|
}()
|
||||||
var errExit error
|
var errExit error
|
||||||
var n int
|
var n int
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -51,6 +52,9 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
|
err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to close conn?: %s", err)
|
||||||
|
}
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
}
|
}
|
||||||
|
@ -2,42 +2,74 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
srvAddress string
|
srvAddress string
|
||||||
peerID string
|
peerID string
|
||||||
|
|
||||||
wg sync.WaitGroup
|
reconnectTime time.Duration
|
||||||
|
|
||||||
clients map[string]*Client
|
mu sync.Mutex
|
||||||
clientsMutex sync.RWMutex
|
client *Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
|
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
return &Manager{
|
return &Manager{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: cancel,
|
srvAddress: serverAddress,
|
||||||
srvAddress: serverAddress,
|
peerID: peerID,
|
||||||
peerID: peerID,
|
reconnectTime: 5 * time.Second,
|
||||||
clients: make(map[string]*Client),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Teardown() {
|
func (m *Manager) Serve() {
|
||||||
m.ctxCancel()
|
ok := m.mu.TryLock()
|
||||||
m.wg.Wait()
|
if !ok {
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) newSrvConnection(address string) {
|
|
||||||
if _, ok := m.clients[address]; ok {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// client := NewClient(address, m.peerID)
|
m.client = NewClient(m.ctx, m.srvAddress, m.peerID)
|
||||||
//err = client.Connect()
|
|
||||||
|
go func() {
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// todo this is not thread safe
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
m.connect()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(2 * time.Second): //timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) OpenConn(peerKey string) (net.Conn, error) {
|
||||||
|
// todo m.client nil check
|
||||||
|
return m.client.OpenConn(peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect is blocking
|
||||||
|
func (m *Manager) connect() {
|
||||||
|
err := m.client.Connect()
|
||||||
|
if err != nil {
|
||||||
|
if m.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("connection error with '%s': %s", m.srvAddress, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,4 +23,6 @@ func main() {
|
|||||||
log.Errorf("failed to bind server: %s", err)
|
log.Errorf("failed to bind server: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {}
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,7 @@ func (l *Listener) Close() error {
|
|||||||
l.lock.Lock()
|
l.lock.Lock()
|
||||||
defer l.lock.Unlock()
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
|
log.Infof("closing UDP server")
|
||||||
if l.listener == nil {
|
if l.listener == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -95,6 +96,7 @@ func (l *Listener) readLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pConn = NewConn(l.listener, addr)
|
pConn = NewConn(l.listener, addr)
|
||||||
|
log.Infof("new connection from: %s", pConn.RemoteAddr())
|
||||||
l.conns[addr.String()] = pConn
|
l.conns[addr.String()] = pConn
|
||||||
go l.onAcceptFn(pConn)
|
go l.onAcceptFn(pConn)
|
||||||
pConn.onNewMsg(buf[:n])
|
pConn.onNewMsg(buf[:n])
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package ws
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,7 +26,7 @@ func NewConn(wsConn *websocket.Conn) *Conn {
|
|||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
t, r, err := c.NextReader()
|
t, r, err := c.NextReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, ioErrHandling(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t != websocket.BinaryMessage {
|
if t != websocket.BinaryMessage {
|
||||||
@ -32,7 +34,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
return 0, fmt.Errorf("unexpected message type")
|
return 0, fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.Read(b)
|
n, err = r.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, ioErrHandling(err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(b []byte) (int, error) {
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
@ -55,3 +61,14 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ioErrHandling(err error) error {
|
||||||
|
var wErr *websocket.CloseError
|
||||||
|
if !errors.As(err, &wErr) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if wErr.Code == websocket.CloseNormalClosure {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
@ -42,7 +42,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
|||||||
Addr: l.address,
|
Addr: l.address,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("WS server is listening on address: %s", l.address)
|
log.Infof("WS server is listening on address: %s", l.address)
|
||||||
err := l.server.ListenAndServe()
|
err := l.server.ListenAndServe()
|
||||||
if errors.Is(err, http.ErrServerClosed) {
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
return nil
|
return nil
|
||||||
@ -77,6 +77,7 @@ func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn := NewConn(wsConn)
|
conn := NewConn(wsConn)
|
||||||
|
log.Infof("new connection from: %s", conn.RemoteAddr())
|
||||||
l.acceptFn(conn)
|
l.acceptFn(conn)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,6 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewPeer(id []byte, conn net.Conn) *Peer {
|
func NewPeer(id []byte, conn net.Conn) *Peer {
|
||||||
log.Debugf("new peer: %v", id)
|
|
||||||
stringID := messages.HashIDToString(id)
|
stringID := messages.HashIDToString(id)
|
||||||
return &Peer{
|
return &Peer{
|
||||||
Log: log.WithField("peer_id", stringID),
|
Log: log.WithField("peer_id", stringID),
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener"
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server
|
// Server
|
||||||
@ -19,7 +22,8 @@ import (
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
store *Store
|
store *Store
|
||||||
|
|
||||||
listener listener.Listener
|
UDPListener listener.Listener
|
||||||
|
WSListener listener.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
@ -29,15 +33,45 @@ func NewServer() *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Server) Listen(address string) error {
|
func (r *Server) Listen(address string) error {
|
||||||
r.listener = udp.NewListener(address)
|
wg := sync.WaitGroup{}
|
||||||
return r.listener.Listen(r.accept)
|
wg.Add(2)
|
||||||
|
|
||||||
|
r.WSListener = ws.NewListener(address)
|
||||||
|
var wslErr error
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
wslErr = r.WSListener.Listen(r.accept)
|
||||||
|
if wslErr != nil {
|
||||||
|
log.Errorf("failed to bind ws server: %s", wslErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
r.UDPListener = udp.NewListener(address)
|
||||||
|
var udpLErr error
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
udpLErr = r.UDPListener.Listen(r.accept)
|
||||||
|
if udpLErr != nil {
|
||||||
|
log.Errorf("failed to bind ws server: %s", udpLErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := errors.Join(wslErr, udpLErr)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Server) Close() error {
|
func (r *Server) Close() error {
|
||||||
if r.listener == nil {
|
var wErr error
|
||||||
return nil
|
if r.WSListener != nil {
|
||||||
|
wErr = r.WSListener.Close()
|
||||||
}
|
}
|
||||||
return r.listener.Close()
|
|
||||||
|
var uErr error
|
||||||
|
if r.UDPListener != nil {
|
||||||
|
uErr = r.UDPListener.Close()
|
||||||
|
}
|
||||||
|
err := errors.Join(wErr, uErr)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Server) accept(conn net.Conn) {
|
func (r *Server) accept(conn net.Conn) {
|
||||||
@ -50,12 +84,12 @@ func (r *Server) accept(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr())
|
peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
|
|
||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
defer func() {
|
defer func() {
|
||||||
peer.Log.Debugf("teardown connection")
|
|
||||||
r.store.DeletePeer(peer)
|
r.store.DeletePeer(peer)
|
||||||
|
peer.Log.Infof("peer left")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@ -20,6 +21,8 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@ -36,21 +39,21 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := client.NewClient(addr, "alice")
|
clientAlice := client.NewClient(ctx, addr, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientAlice.Close()
|
defer clientAlice.Close()
|
||||||
|
|
||||||
clientPlaceHolder := client.NewClient(addr, "clientPlaceHolder")
|
clientPlaceHolder := client.NewClient(ctx, addr, "clientPlaceHolder")
|
||||||
err = clientPlaceHolder.Connect()
|
err = clientPlaceHolder.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientPlaceHolder.Close()
|
defer clientPlaceHolder.Close()
|
||||||
|
|
||||||
clientBob := client.NewClient(addr, "bob")
|
clientBob := client.NewClient(ctx, addr, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -87,6 +90,7 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistration(t *testing.T) {
|
func TestRegistration(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@ -103,7 +107,7 @@ func TestRegistration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := client.NewClient(addr, "alice")
|
clientAlice := client.NewClient(ctx, addr, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -117,6 +121,7 @@ func TestRegistration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistrationTimeout(t *testing.T) {
|
func TestRegistrationTimeout(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
Port: 1234,
|
Port: 1234,
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
@ -135,7 +140,7 @@ func TestRegistrationTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer tcpListener.Close()
|
defer tcpListener.Close()
|
||||||
|
|
||||||
clientAlice := client.NewClient("127.0.0.1:1234", "alice")
|
clientAlice := client.NewClient(ctx, "127.0.0.1:1234", "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -149,6 +154,7 @@ func TestRegistrationTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEcho(t *testing.T) {
|
func TestEcho(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
@ -167,7 +173,7 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := client.NewClient(addr, idAlice)
|
clientAlice := client.NewClient(ctx, addr, idAlice)
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -179,7 +185,7 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := client.NewClient(addr, idBob)
|
clientBob := client.NewClient(ctx, addr, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -229,6 +235,8 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBindToUnavailabePeer(t *testing.T) {
|
func TestBindToUnavailabePeer(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@ -246,7 +254,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := client.NewClient(addr, "alice")
|
clientAlice := client.NewClient(ctx, addr, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -266,6 +274,8 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBindReconnect(t *testing.T) {
|
func TestBindReconnect(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@ -283,7 +293,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := client.NewClient(addr, "alice")
|
clientAlice := client.NewClient(ctx, addr, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -294,7 +304,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientBob := client.NewClient(addr, "bob")
|
clientBob := client.NewClient(ctx, addr, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -311,7 +321,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice = client.NewClient(addr, "alice")
|
clientAlice = client.NewClient(ctx, addr, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
57
relay/test/manager_test.go
Normal file
57
relay/test/manager_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/client"
|
||||||
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager(t *testing.T) {
|
||||||
|
addr := "localhost:1239"
|
||||||
|
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cm := client.NewManager(ctx, addr, "me")
|
||||||
|
cm.Serve()
|
||||||
|
|
||||||
|
// wait for the relay handshake to complete
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
conn, err := cm.OpenConn("remotepeer")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to open connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
readCtx, readCancel := context.WithCancel(context.Background())
|
||||||
|
defer readCancel()
|
||||||
|
go func() {
|
||||||
|
_, _ = conn.Read(make([]byte, 1))
|
||||||
|
readCancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Errorf("client peer conn did not close automatically")
|
||||||
|
case <-readCtx.Done():
|
||||||
|
// conn exited well
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user