mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Add close message type
This commit is contained in:
parent
a40d4d2f32
commit
fed9e587af
@ -87,7 +87,7 @@ func (c *Client) Connect() error {
|
||||
var ctx context.Context
|
||||
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
||||
context.AfterFunc(ctx, func() {
|
||||
cErr := c.Close()
|
||||
cErr := c.close(false)
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close relay connection: %s", cErr)
|
||||
}
|
||||
@ -144,22 +144,30 @@ func (c *Client) HasConns() bool {
|
||||
|
||||
// Close closes the connection to the relay server and all connections to other peers.
|
||||
func (c *Client) Close() error {
|
||||
return c.close(false)
|
||||
}
|
||||
|
||||
func (c *Client) close(byServer bool) error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
err = c.relayConn.Close()
|
||||
c.closeAllConns()
|
||||
if !byServer {
|
||||
c.writeCloseMsg()
|
||||
err = c.relayConn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay client ha been closed: %s", c.serverAddress)
|
||||
c.log.Infof("relay connection closed with: %s", c.serverAddress)
|
||||
c.ctxCancel()
|
||||
return err
|
||||
}
|
||||
@ -232,8 +240,11 @@ func (c *Client) handShake() error {
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
var errExit error
|
||||
var n int
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
closedByServer bool
|
||||
)
|
||||
for {
|
||||
buf := make([]byte, bufferSize)
|
||||
n, errExit = relayConn.Read(buf)
|
||||
@ -243,7 +254,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
break
|
||||
goto Exit
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||
@ -264,7 +275,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
break
|
||||
goto Exit
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
c.mu.Unlock()
|
||||
@ -273,16 +284,19 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
continue
|
||||
}
|
||||
|
||||
// todo review is this can cause panic
|
||||
container.messages <- Msg{buf[:n]}
|
||||
case messages.MsgClose:
|
||||
closedByServer = true
|
||||
log.Debugf("relay connection close by server")
|
||||
goto Exit
|
||||
}
|
||||
}
|
||||
|
||||
Exit:
|
||||
c.notifyDisconnected()
|
||||
|
||||
c.log.Tracef("exit from read loop")
|
||||
c.wgReadLoop.Done()
|
||||
|
||||
c.Close()
|
||||
_ = c.close(closedByServer)
|
||||
}
|
||||
|
||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||
@ -365,3 +379,11 @@ func (c *Client) notifyDisconnected() {
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
}
|
||||
|
||||
func (c *Client) writeCloseMsg() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
_, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send close message: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -100,57 +100,56 @@ func TestRegistration(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
_ = srv.Close()
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}()
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
err = srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer udpListener.Close()
|
||||
defer func(fakeUDPListener *net.UDPConn) {
|
||||
_ = fakeUDPListener.Close()
|
||||
}(fakeUDPListener)
|
||||
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer tcpListener.Close()
|
||||
defer func(fakeTCPListener *net.TCPListener) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}()
|
||||
log.Debugf("%s", err)
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
@ -259,18 +258,16 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Infof("closing client")
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
@ -315,7 +312,7 @@ func TestBindReconnect(t *testing.T) {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
log.Infof("closing client Alice")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
@ -403,52 +400,6 @@ func TestCloseConn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_ = clientAlice.relayConn.Close()
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout)
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@ -491,3 +442,82 @@ func TestCloseRelayConn(t *testing.T) {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
err = srv1.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
err = relayClient.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
|
||||
err = srv.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
@ -52,9 +51,5 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
|
||||
if err != nil {
|
||||
log.Errorf("failed to close conn?: %s", err)
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
@ -3,13 +3,17 @@ package ws
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
addr := fmt.Sprintf("ws://" + address)
|
||||
wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil)
|
||||
wsDialer := websocket.Dialer{
|
||||
HandshakeTimeout: 3 * time.Second,
|
||||
}
|
||||
wsConn, _, err := wsDialer.Dial(addr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -84,8 +84,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
if !foreign {
|
||||
log.Debugf("open connection to permanent server: %s", peerKey)
|
||||
return m.relayClient.OpenConn(peerKey)
|
||||
} else {
|
||||
log.Debugf("open connection to foreign server: %s", serverAddress)
|
||||
return m.openConnVia(serverAddress, peerKey)
|
||||
}
|
||||
}
|
||||
|
@ -47,12 +47,14 @@ func TestForeignConn(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
clientAlice := NewManager(ctx, addr1, idAlice)
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr1, idAlice)
|
||||
clientAlice.Serve()
|
||||
|
||||
idBob := "bob"
|
||||
log.Debugf("connect by bob")
|
||||
clientBob := NewManager(ctx, addr2, idBob)
|
||||
clientBob := NewManager(mCtx, addr2, idBob)
|
||||
clientBob.Serve()
|
||||
|
||||
bobsSrvAddr, err := clientBob.RelayAddress()
|
||||
@ -132,61 +134,9 @@ func TestForeginConnClose(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
clientAlice := NewManager(ctx, addr1, idAlice)
|
||||
clientAlice.Serve()
|
||||
|
||||
conn, err := clientAlice.OpenConn(addr2, "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
select {}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
err := srv2.Listen(addr2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mgr := NewManager(ctx, addr1, idAlice)
|
||||
relayCleanupInterval = 2 * time.Second
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||
@ -198,9 +148,124 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.Logf("closing server 1.")
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(addr2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
t.Logf("closing server 2.")
|
||||
err := srv2.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 2 closed.")
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("close conn")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr, "alice")
|
||||
clientAlice.Serve()
|
||||
ra, err := clientAlice.RelayAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ra.String(), "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("closing client relay connection")
|
||||
// todo figure out moc server
|
||||
_ = clientAlice.relayClient.relayConn.Close()
|
||||
t.Log("start test reading")
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ra.String(), "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ const (
|
||||
MsgTypeHello MsgType = 0
|
||||
MsgTypeHelloResponse MsgType = 1
|
||||
MsgTypeTransport MsgType = 2
|
||||
MsgClose MsgType = 3
|
||||
)
|
||||
|
||||
var (
|
||||
@ -26,6 +27,8 @@ func (m MsgType) String() string {
|
||||
return "hello response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgClose:
|
||||
return "close"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
@ -39,6 +42,8 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
return msgType, nil
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
case MsgClose:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
||||
}
|
||||
@ -52,6 +57,8 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
return msgType, nil
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
case MsgClose:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
||||
}
|
||||
@ -81,6 +88,14 @@ func MarshalHelloResponse() []byte {
|
||||
return msg
|
||||
}
|
||||
|
||||
// Close message
|
||||
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, 1)
|
||||
msg[0] = byte(MsgClose)
|
||||
return msg
|
||||
}
|
||||
|
||||
// Transport message
|
||||
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
||||
|
@ -5,4 +5,5 @@ import "net"
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Close() error
|
||||
WaitForExitAcceptedConns()
|
||||
}
|
||||
|
@ -21,6 +21,11 @@ type Listener struct {
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
@ -61,11 +66,11 @@ func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
log.Infof("closing UDP server")
|
||||
if l.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("closing UDP listener")
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
|
@ -33,8 +33,11 @@ func NewListener(address string) listener.Listener {
|
||||
}
|
||||
}
|
||||
|
||||
// Listen todo: prevent multiple call
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
if l.server != nil {
|
||||
return errors.New("server is already running")
|
||||
}
|
||||
|
||||
l.acceptFn = acceptFn
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.onAccept)
|
||||
@ -69,6 +72,10 @@ func (l *Listener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -17,7 +18,9 @@ type Conn struct {
|
||||
lAddr *net.TCPAddr
|
||||
rAddr *net.TCPAddr
|
||||
|
||||
ctx context.Context
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
||||
@ -32,7 +35,7 @@ func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, ioErrHandling(err)
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
@ -42,7 +45,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
|
||||
n, err = r.Read(b)
|
||||
if err != nil {
|
||||
return 0, ioErrHandling(err)
|
||||
return 0, c.ioErrHandling(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
@ -76,11 +79,23 @@ func (c *Conn) SetDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
|
||||
// todo: fix io.EOF handling
|
||||
func ioErrHandling(err error) error {
|
||||
func (c *Conn) isClosed() bool {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
if !errors.As(err, &wErr) {
|
||||
return err
|
||||
|
@ -56,15 +56,18 @@ func (l *Listener) Close() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
log.Debugf("closing WS server")
|
||||
log.Infof("stop WS listener")
|
||||
if err := l.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("server shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
l.wg.Wait()
|
||||
log.Infof("WS listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
|
@ -15,11 +15,9 @@ import (
|
||||
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
|
||||
)
|
||||
|
||||
// Server
|
||||
// todo:
|
||||
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
|
||||
type Server struct {
|
||||
store *Store
|
||||
store *Store
|
||||
storeMu sync.RWMutex
|
||||
|
||||
UDPListener listener.Listener
|
||||
WSListener listener.Listener
|
||||
@ -27,7 +25,8 @@ type Server struct {
|
||||
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
store: NewStore(),
|
||||
store: NewStore(),
|
||||
storeMu: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,6 +68,11 @@ func (r *Server) Close() error {
|
||||
if r.UDPListener != nil {
|
||||
uErr = r.UDPListener.Close()
|
||||
}
|
||||
|
||||
r.sendCloseMsgs()
|
||||
|
||||
r.WSListener.WaitForExitAcceptedConns()
|
||||
|
||||
err := errors.Join(wErr, uErr)
|
||||
return err
|
||||
}
|
||||
@ -88,7 +92,7 @@ func (r *Server) accept(conn net.Conn) {
|
||||
r.store.AddPeer(peer)
|
||||
defer func() {
|
||||
r.store.DeletePeer(peer)
|
||||
peer.Log.Infof("peer left")
|
||||
peer.Log.Infof("relay connection closed")
|
||||
}()
|
||||
|
||||
for {
|
||||
@ -132,10 +136,33 @@ func (r *Server) accept(conn net.Conn) {
|
||||
}
|
||||
return
|
||||
}()
|
||||
case messages.MsgClose:
|
||||
peer.Log.Infof("peer disconnected gracefully")
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Server) sendCloseMsgs() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
|
||||
r.storeMu.Lock()
|
||||
log.Debugf("sending close messages to %d peers", len(r.store.peers))
|
||||
for _, p := range r.store.peers {
|
||||
_, err := p.conn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
r.storeMu.Unlock()
|
||||
}
|
||||
|
||||
func handShake(conn net.Conn) (*Peer, error) {
|
||||
buf := make([]byte, 1500)
|
||||
n, err := conn.Read(buf)
|
||||
|
@ -24,7 +24,6 @@ func (s *Store) AddPeer(peer *Peer) {
|
||||
func (s *Store) DeletePeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
delete(s.peers, peer.String())
|
||||
}
|
||||
|
||||
@ -35,3 +34,14 @@ func (s *Store) Peer(id string) (*Peer, bool) {
|
||||
p, ok := s.peers[id]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
func (s *Store) Peers() []*Peer {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
peers := make([]*Peer, 0, len(s.peers))
|
||||
for _, p := range s.peers {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user