mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Add close message type
This commit is contained in:
parent
a40d4d2f32
commit
fed9e587af
@ -87,7 +87,7 @@ func (c *Client) Connect() error {
|
|||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
||||||
context.AfterFunc(ctx, func() {
|
context.AfterFunc(ctx, func() {
|
||||||
cErr := c.Close()
|
cErr := c.close(false)
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
log.Errorf("failed to close relay connection: %s", cErr)
|
log.Errorf("failed to close relay connection: %s", cErr)
|
||||||
}
|
}
|
||||||
@ -144,22 +144,30 @@ func (c *Client) HasConns() bool {
|
|||||||
|
|
||||||
// Close closes the connection to the relay server and all connections to other peers.
|
// Close closes the connection to the relay server and all connections to other peers.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
|
return c.close(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) close(byServer bool) error {
|
||||||
c.readLoopMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.readLoopMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
var err error
|
var err error
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.serviceIsRunning = false
|
c.serviceIsRunning = false
|
||||||
err = c.relayConn.Close()
|
|
||||||
c.closeAllConns()
|
c.closeAllConns()
|
||||||
|
if !byServer {
|
||||||
|
c.writeCloseMsg()
|
||||||
|
err = c.relayConn.Close()
|
||||||
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
c.wgReadLoop.Wait()
|
c.wgReadLoop.Wait()
|
||||||
c.log.Infof("relay client ha been closed: %s", c.serverAddress)
|
c.log.Infof("relay connection closed with: %s", c.serverAddress)
|
||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -232,8 +240,11 @@ func (c *Client) handShake() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readLoop(relayConn net.Conn) {
|
func (c *Client) readLoop(relayConn net.Conn) {
|
||||||
var errExit error
|
var (
|
||||||
var n int
|
errExit error
|
||||||
|
n int
|
||||||
|
closedByServer bool
|
||||||
|
)
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
n, errExit = relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
@ -243,7 +254,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
break
|
goto Exit
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||||
@ -264,7 +275,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
break
|
goto Exit
|
||||||
}
|
}
|
||||||
container, ok := c.conns[stringID]
|
container, ok := c.conns[stringID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@ -273,16 +284,19 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo review is this can cause panic
|
||||||
container.messages <- Msg{buf[:n]}
|
container.messages <- Msg{buf[:n]}
|
||||||
|
case messages.MsgClose:
|
||||||
|
closedByServer = true
|
||||||
|
log.Debugf("relay connection close by server")
|
||||||
|
goto Exit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Exit:
|
||||||
c.notifyDisconnected()
|
c.notifyDisconnected()
|
||||||
|
|
||||||
c.log.Tracef("exit from read loop")
|
|
||||||
c.wgReadLoop.Done()
|
c.wgReadLoop.Done()
|
||||||
|
_ = c.close(closedByServer)
|
||||||
c.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||||
@ -365,3 +379,11 @@ func (c *Client) notifyDisconnected() {
|
|||||||
}
|
}
|
||||||
go c.onDisconnectListener()
|
go c.onDisconnectListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) writeCloseMsg() {
|
||||||
|
msg := messages.MarshalCloseMsg()
|
||||||
|
_, err := c.relayConn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("failed to send close message: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -100,57 +100,56 @@ func TestRegistration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := srv.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, addr, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = srv.Close()
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
err = clientAlice.Close()
|
err = clientAlice.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to close conn: %s", err)
|
t.Errorf("failed to close conn: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
err = srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistrationTimeout(t *testing.T) {
|
func TestRegistrationTimeout(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
Port: 1234,
|
Port: 1234,
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind UDP server: %s", err)
|
t.Fatalf("failed to bind UDP server: %s", err)
|
||||||
}
|
}
|
||||||
defer udpListener.Close()
|
defer func(fakeUDPListener *net.UDPConn) {
|
||||||
|
_ = fakeUDPListener.Close()
|
||||||
|
}(fakeUDPListener)
|
||||||
|
|
||||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||||
Port: 1234,
|
Port: 1234,
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind TCP server: %s", err)
|
t.Fatalf("failed to bind TCP server: %s", err)
|
||||||
}
|
}
|
||||||
defer tcpListener.Close()
|
defer func(fakeTCPListener *net.TCPListener) {
|
||||||
|
_ = fakeTCPListener.Close()
|
||||||
|
}(fakeTCPListener)
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
|
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
log.Debugf("%s", err)
|
||||||
err = clientAlice.Close()
|
err = clientAlice.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to close conn: %s", err)
|
t.Errorf("failed to close conn: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEcho(t *testing.T) {
|
func TestEcho(t *testing.T) {
|
||||||
@ -259,18 +258,16 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
log.Infof("closing client")
|
|
||||||
err := clientAlice.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close client: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn("bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Infof("closing client")
|
||||||
|
err = clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBindReconnect(t *testing.T) {
|
func TestBindReconnect(t *testing.T) {
|
||||||
@ -315,7 +312,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("closing client")
|
log.Infof("closing client Alice")
|
||||||
err = clientAlice.Close()
|
err = clientAlice.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
@ -403,52 +400,6 @@ func TestCloseConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoReconnect(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
addr := "localhost:1234"
|
|
||||||
srv := server.NewServer()
|
|
||||||
go func() {
|
|
||||||
err := srv.Listen(addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to bind server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := srv.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to close server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
|
||||||
err := clientAlice.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := clientAlice.OpenConn("bob")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = clientAlice.relayConn.Close()
|
|
||||||
|
|
||||||
_, err = conn.Read(make([]byte, 1))
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("unexpected reading from closed connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("waiting for reconnection")
|
|
||||||
time.Sleep(reconnectingTimeout)
|
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to open channel: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseRelayConn(t *testing.T) {
|
func TestCloseRelayConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@ -491,3 +442,82 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseByServer(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
addr1 := "localhost:1234"
|
||||||
|
srv1 := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv1.Listen(addr1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
idAlice := "alice"
|
||||||
|
log.Debugf("connect by alice")
|
||||||
|
relayClient := NewClient(ctx, addr1, idAlice)
|
||||||
|
err := relayClient.Connect()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
disconnected := make(chan struct{})
|
||||||
|
relayClient.SetOnDisconnectListener(func() {
|
||||||
|
log.Infof("client disconnected")
|
||||||
|
close(disconnected)
|
||||||
|
})
|
||||||
|
|
||||||
|
err = srv1.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-disconnected:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
log.Fatalf("timeout waiting for client to disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = relayClient.OpenConn("bob")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseByClient(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
addr1 := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
idAlice := "alice"
|
||||||
|
log.Debugf("connect by alice")
|
||||||
|
relayClient := NewClient(ctx, addr1, idAlice)
|
||||||
|
err := relayClient.Connect()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = relayClient.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = relayClient.OpenConn("bob")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -52,9 +51,5 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to close conn?: %s", err)
|
|
||||||
}
|
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
}
|
}
|
||||||
|
@ -3,13 +3,17 @@ package ws
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Dial(address string) (net.Conn, error) {
|
func Dial(address string) (net.Conn, error) {
|
||||||
addr := fmt.Sprintf("ws://" + address)
|
addr := fmt.Sprintf("ws://" + address)
|
||||||
wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil)
|
wsDialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
wsConn, _, err := wsDialer.Dial(addr, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -84,8 +84,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foreign {
|
if !foreign {
|
||||||
|
log.Debugf("open connection to permanent server: %s", peerKey)
|
||||||
return m.relayClient.OpenConn(peerKey)
|
return m.relayClient.OpenConn(peerKey)
|
||||||
} else {
|
} else {
|
||||||
|
log.Debugf("open connection to foreign server: %s", serverAddress)
|
||||||
return m.openConnVia(serverAddress, peerKey)
|
return m.openConnVia(serverAddress, peerKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,12 +47,14 @@ func TestForeignConn(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
clientAlice := NewManager(ctx, addr1, idAlice)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
clientAlice := NewManager(mCtx, addr1, idAlice)
|
||||||
clientAlice.Serve()
|
clientAlice.Serve()
|
||||||
|
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
log.Debugf("connect by bob")
|
log.Debugf("connect by bob")
|
||||||
clientBob := NewManager(ctx, addr2, idBob)
|
clientBob := NewManager(mCtx, addr2, idBob)
|
||||||
clientBob.Serve()
|
clientBob.Serve()
|
||||||
|
|
||||||
bobsSrvAddr, err := clientBob.RelayAddress()
|
bobsSrvAddr, err := clientBob.RelayAddress()
|
||||||
@ -132,61 +134,9 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
clientAlice := NewManager(ctx, addr1, idAlice)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
clientAlice.Serve()
|
defer cancel()
|
||||||
|
mgr := NewManager(mCtx, addr1, idAlice)
|
||||||
conn, err := clientAlice.OpenConn(addr2, "anotherpeer")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForeginAutoClose(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
addr1 := "localhost:1234"
|
|
||||||
srv1 := server.NewServer()
|
|
||||||
go func() {
|
|
||||||
err := srv1.Listen(addr1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := srv1.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
addr2 := "localhost:2234"
|
|
||||||
srv2 := server.NewServer()
|
|
||||||
go func() {
|
|
||||||
err := srv2.Listen(addr2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := srv2.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
|
||||||
mgr := NewManager(ctx, addr1, idAlice)
|
|
||||||
relayCleanupInterval = 2 * time.Second
|
|
||||||
mgr.Serve()
|
mgr.Serve()
|
||||||
|
|
||||||
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||||
@ -198,9 +148,124 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to close connection: %s", err)
|
t.Fatalf("failed to close connection: %s", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForeginAutoClose(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
relayCleanupInterval = 1 * time.Second
|
||||||
|
addr1 := "localhost:1234"
|
||||||
|
srv1 := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
t.Log("binding server 1.")
|
||||||
|
err := srv1.Listen(addr1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
t.Logf("closing server 1.")
|
||||||
|
err := srv1.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
t.Logf("server 1. closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr2 := "localhost:2234"
|
||||||
|
srv2 := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
t.Log("binding server 2.")
|
||||||
|
err := srv2.Listen(addr2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
t.Logf("closing server 2.")
|
||||||
|
err := srv2.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
t.Logf("server 2 closed.")
|
||||||
|
}()
|
||||||
|
|
||||||
|
idAlice := "alice"
|
||||||
|
t.Log("connect to server 1.")
|
||||||
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
mgr := NewManager(mCtx, addr1, idAlice)
|
||||||
|
mgr.Serve()
|
||||||
|
|
||||||
|
t.Log("open connection to another peer")
|
||||||
|
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("close conn")
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to close connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||||
if len(mgr.relayClients) != 0 {
|
if len(mgr.relayClients) != 0 {
|
||||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Logf("closing manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoReconnect(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
reconnectingTimeout = 2 * time.Second
|
||||||
|
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
clientAlice := NewManager(mCtx, addr, "alice")
|
||||||
|
clientAlice.Serve()
|
||||||
|
ra, err := clientAlice.RelayAddress()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to get relay address: %s", err)
|
||||||
|
}
|
||||||
|
conn, err := clientAlice.OpenConn(ra.String(), "bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("closing client relay connection")
|
||||||
|
// todo figure out moc server
|
||||||
|
_ = clientAlice.relayClient.relayConn.Close()
|
||||||
|
t.Log("start test reading")
|
||||||
|
_, err = conn.Read(make([]byte, 1))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("unexpected reading from closed connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("waiting for reconnection")
|
||||||
|
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||||
|
|
||||||
|
log.Infof("reopent the connection")
|
||||||
|
_, err = clientAlice.OpenConn(ra.String(), "bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to open channel: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ const (
|
|||||||
MsgTypeHello MsgType = 0
|
MsgTypeHello MsgType = 0
|
||||||
MsgTypeHelloResponse MsgType = 1
|
MsgTypeHelloResponse MsgType = 1
|
||||||
MsgTypeTransport MsgType = 2
|
MsgTypeTransport MsgType = 2
|
||||||
|
MsgClose MsgType = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -26,6 +27,8 @@ func (m MsgType) String() string {
|
|||||||
return "hello response"
|
return "hello response"
|
||||||
case MsgTypeTransport:
|
case MsgTypeTransport:
|
||||||
return "transport"
|
return "transport"
|
||||||
|
case MsgClose:
|
||||||
|
return "close"
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
@ -39,6 +42,8 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
|||||||
return msgType, nil
|
return msgType, nil
|
||||||
case MsgTypeTransport:
|
case MsgTypeTransport:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
|
case MsgClose:
|
||||||
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
||||||
}
|
}
|
||||||
@ -52,6 +57,8 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
|||||||
return msgType, nil
|
return msgType, nil
|
||||||
case MsgTypeTransport:
|
case MsgTypeTransport:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
|
case MsgClose:
|
||||||
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
||||||
}
|
}
|
||||||
@ -81,6 +88,14 @@ func MarshalHelloResponse() []byte {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close message
|
||||||
|
|
||||||
|
func MarshalCloseMsg() []byte {
|
||||||
|
msg := make([]byte, 1)
|
||||||
|
msg[0] = byte(MsgClose)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
// Transport message
|
// Transport message
|
||||||
|
|
||||||
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
||||||
|
@ -5,4 +5,5 @@ import "net"
|
|||||||
type Listener interface {
|
type Listener interface {
|
||||||
Listen(func(conn net.Conn)) error
|
Listen(func(conn net.Conn)) error
|
||||||
Close() error
|
Close() error
|
||||||
|
WaitForExitAcceptedConns()
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,11 @@ type Listener struct {
|
|||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Listener) WaitForExitAcceptedConns() {
|
||||||
|
l.wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func NewListener(address string) listener.Listener {
|
func NewListener(address string) listener.Listener {
|
||||||
return &Listener{
|
return &Listener{
|
||||||
address: address,
|
address: address,
|
||||||
@ -61,11 +66,11 @@ func (l *Listener) Close() error {
|
|||||||
l.lock.Lock()
|
l.lock.Lock()
|
||||||
defer l.lock.Unlock()
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
log.Infof("closing UDP server")
|
|
||||||
if l.listener == nil {
|
if l.listener == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Infof("closing UDP listener")
|
||||||
close(l.quit)
|
close(l.quit)
|
||||||
err := l.listener.Close()
|
err := l.listener.Close()
|
||||||
l.wg.Wait()
|
l.wg.Wait()
|
||||||
|
@ -33,8 +33,11 @@ func NewListener(address string) listener.Listener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen todo: prevent multiple call
|
|
||||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||||
|
if l.server != nil {
|
||||||
|
return errors.New("server is already running")
|
||||||
|
}
|
||||||
|
|
||||||
l.acceptFn = acceptFn
|
l.acceptFn = acceptFn
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", l.onAccept)
|
mux.HandleFunc("/", l.onAccept)
|
||||||
@ -69,6 +72,10 @@ func (l *Listener) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Listener) WaitForExitAcceptedConns() {
|
||||||
|
l.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
||||||
l.wg.Add(1)
|
l.wg.Add(1)
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -17,6 +18,8 @@ type Conn struct {
|
|||||||
lAddr *net.TCPAddr
|
lAddr *net.TCPAddr
|
||||||
rAddr *net.TCPAddr
|
rAddr *net.TCPAddr
|
||||||
|
|
||||||
|
closed bool
|
||||||
|
closedMu sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,7 +35,7 @@ func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
|||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
t, r, err := c.Reader(c.ctx)
|
t, r, err := c.Reader(c.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ioErrHandling(err)
|
return 0, c.ioErrHandling(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t != websocket.MessageBinary {
|
if t != websocket.MessageBinary {
|
||||||
@ -42,7 +45,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
|
|
||||||
n, err = r.Read(b)
|
n, err = r.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ioErrHandling(err)
|
return 0, c.ioErrHandling(err)
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
@ -76,11 +79,23 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
c.closedMu.Lock()
|
||||||
|
c.closed = true
|
||||||
|
c.closedMu.Unlock()
|
||||||
|
return c.Conn.CloseNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isClosed() bool {
|
||||||
|
c.closedMu.Lock()
|
||||||
|
defer c.closedMu.Unlock()
|
||||||
|
return c.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ioErrHandling(err error) error {
|
||||||
|
if c.isClosed() {
|
||||||
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: fix io.EOF handling
|
|
||||||
func ioErrHandling(err error) error {
|
|
||||||
var wErr *websocket.CloseError
|
var wErr *websocket.CloseError
|
||||||
if !errors.As(err, &wErr) {
|
if !errors.As(err, &wErr) {
|
||||||
return err
|
return err
|
||||||
|
@ -56,15 +56,18 @@ func (l *Listener) Close() error {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
log.Debugf("closing WS server")
|
log.Infof("stop WS listener")
|
||||||
if err := l.server.Shutdown(ctx); err != nil {
|
if err := l.server.Shutdown(ctx); err != nil {
|
||||||
return fmt.Errorf("server shutdown failed: %v", err)
|
return fmt.Errorf("server shutdown failed: %v", err)
|
||||||
}
|
}
|
||||||
|
log.Infof("WS listener stopped")
|
||||||
l.wg.Wait()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Listener) WaitForExitAcceptedConns() {
|
||||||
|
l.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||||
l.wg.Add(1)
|
l.wg.Add(1)
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
@ -15,11 +15,9 @@ import (
|
|||||||
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
|
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server
|
|
||||||
// todo:
|
|
||||||
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
store *Store
|
store *Store
|
||||||
|
storeMu sync.RWMutex
|
||||||
|
|
||||||
UDPListener listener.Listener
|
UDPListener listener.Listener
|
||||||
WSListener listener.Listener
|
WSListener listener.Listener
|
||||||
@ -28,6 +26,7 @@ type Server struct {
|
|||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
store: NewStore(),
|
store: NewStore(),
|
||||||
|
storeMu: sync.RWMutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,6 +68,11 @@ func (r *Server) Close() error {
|
|||||||
if r.UDPListener != nil {
|
if r.UDPListener != nil {
|
||||||
uErr = r.UDPListener.Close()
|
uErr = r.UDPListener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.sendCloseMsgs()
|
||||||
|
|
||||||
|
r.WSListener.WaitForExitAcceptedConns()
|
||||||
|
|
||||||
err := errors.Join(wErr, uErr)
|
err := errors.Join(wErr, uErr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -88,7 +92,7 @@ func (r *Server) accept(conn net.Conn) {
|
|||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
defer func() {
|
defer func() {
|
||||||
r.store.DeletePeer(peer)
|
r.store.DeletePeer(peer)
|
||||||
peer.Log.Infof("peer left")
|
peer.Log.Infof("relay connection closed")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -132,10 +136,33 @@ func (r *Server) accept(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}()
|
}()
|
||||||
|
case messages.MsgClose:
|
||||||
|
peer.Log.Infof("peer disconnected gracefully")
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Server) sendCloseMsgs() {
|
||||||
|
msg := messages.MarshalCloseMsg()
|
||||||
|
|
||||||
|
r.storeMu.Lock()
|
||||||
|
log.Debugf("sending close messages to %d peers", len(r.store.peers))
|
||||||
|
for _, p := range r.store.peers {
|
||||||
|
_, err := p.conn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to send close message to peer: %s", p.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to close connection to peer: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.storeMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func handShake(conn net.Conn) (*Peer, error) {
|
func handShake(conn net.Conn) (*Peer, error) {
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
n, err := conn.Read(buf)
|
n, err := conn.Read(buf)
|
||||||
|
@ -24,7 +24,6 @@ func (s *Store) AddPeer(peer *Peer) {
|
|||||||
func (s *Store) DeletePeer(peer *Peer) {
|
func (s *Store) DeletePeer(peer *Peer) {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
|
|
||||||
delete(s.peers, peer.String())
|
delete(s.peers, peer.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,3 +34,14 @@ func (s *Store) Peer(id string) (*Peer, bool) {
|
|||||||
p, ok := s.peers[id]
|
p, ok := s.peers[id]
|
||||||
return p, ok
|
return p, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) Peers() []*Peer {
|
||||||
|
s.peersLock.RLock()
|
||||||
|
defer s.peersLock.RUnlock()
|
||||||
|
|
||||||
|
peers := make([]*Peer, 0, len(s.peers))
|
||||||
|
for _, p := range s.peers {
|
||||||
|
peers = append(peers, p)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user