Add close message type

This commit is contained in:
Zoltan Papp 2024-06-05 19:49:30 +02:00
parent a40d4d2f32
commit fed9e587af
14 changed files with 371 additions and 170 deletions

View File

@ -87,7 +87,7 @@ func (c *Client) Connect() error {
var ctx context.Context var ctx context.Context
ctx, c.ctxCancel = context.WithCancel(c.parentCtx) ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
context.AfterFunc(ctx, func() { context.AfterFunc(ctx, func() {
cErr := c.Close() cErr := c.close(false)
if cErr != nil { if cErr != nil {
log.Errorf("failed to close relay connection: %s", cErr) log.Errorf("failed to close relay connection: %s", cErr)
} }
@ -144,22 +144,30 @@ func (c *Client) HasConns() bool {
// Close closes the connection to the relay server and all connections to other peers. // Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error { func (c *Client) Close() error {
return c.close(false)
}
func (c *Client) close(byServer bool) error {
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
c.mu.Lock() c.mu.Lock()
var err error var err error
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock()
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
err = c.relayConn.Close()
c.closeAllConns() c.closeAllConns()
if !byServer {
c.writeCloseMsg()
err = c.relayConn.Close()
}
c.mu.Unlock() c.mu.Unlock()
c.wgReadLoop.Wait() c.wgReadLoop.Wait()
c.log.Infof("relay client ha been closed: %s", c.serverAddress) c.log.Infof("relay connection closed with: %s", c.serverAddress)
c.ctxCancel() c.ctxCancel()
return err return err
} }
@ -232,8 +240,11 @@ func (c *Client) handShake() error {
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(relayConn net.Conn) {
var errExit error var (
var n int errExit error
n int
closedByServer bool
)
for { for {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
n, errExit = relayConn.Read(buf) n, errExit = relayConn.Read(buf)
@ -243,7 +254,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Debugf("failed to read message from relay server: %s", errExit) c.log.Debugf("failed to read message from relay server: %s", errExit)
} }
c.mu.Unlock() c.mu.Unlock()
break goto Exit
} }
msgType, err := messages.DetermineServerMsgType(buf[:n]) msgType, err := messages.DetermineServerMsgType(buf[:n])
@ -264,7 +275,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
break goto Exit
} }
container, ok := c.conns[stringID] container, ok := c.conns[stringID]
c.mu.Unlock() c.mu.Unlock()
@ -273,16 +284,19 @@ func (c *Client) readLoop(relayConn net.Conn) {
continue continue
} }
// todo review is this can cause panic
container.messages <- Msg{buf[:n]} container.messages <- Msg{buf[:n]}
case messages.MsgClose:
closedByServer = true
log.Debugf("relay connection close by server")
goto Exit
} }
} }
Exit:
c.notifyDisconnected() c.notifyDisconnected()
c.log.Tracef("exit from read loop")
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(closedByServer)
c.Close()
} }
// todo check by reference too, the id is not enought because the id come from the outer conn // todo check by reference too, the id is not enought because the id come from the outer conn
@ -365,3 +379,11 @@ func (c *Client) notifyDisconnected() {
} }
go c.onDisconnectListener() go c.onDisconnectListener()
} }
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}

View File

@ -100,57 +100,56 @@ func TestRegistration(t *testing.T) {
} }
}() }()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice") clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
_ = srv.Close()
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer func() {
err = clientAlice.Close() err = clientAlice.Close()
if err != nil { if err != nil {
t.Errorf("failed to close conn: %s", err) t.Errorf("failed to close conn: %s", err)
} }
}() err = srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
} }
func TestRegistrationTimeout(t *testing.T) { func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background() ctx := context.Background()
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234, Port: 1234,
IP: net.ParseIP("0.0.0.0"), IP: net.ParseIP("0.0.0.0"),
}) })
if err != nil { if err != nil {
t.Fatalf("failed to bind UDP server: %s", err) t.Fatalf("failed to bind UDP server: %s", err)
} }
defer udpListener.Close() defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{ fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234, Port: 1234,
IP: net.ParseIP("0.0.0.0"), IP: net.ParseIP("0.0.0.0"),
}) })
if err != nil { if err != nil {
t.Fatalf("failed to bind TCP server: %s", err) t.Fatalf("failed to bind TCP server: %s", err)
} }
defer tcpListener.Close() defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice") clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
defer func() { log.Debugf("%s", err)
err = clientAlice.Close() err = clientAlice.Close()
if err != nil { if err != nil {
t.Errorf("failed to close conn: %s", err) t.Errorf("failed to close conn: %s", err)
} }
}()
} }
func TestEcho(t *testing.T) { func TestEcho(t *testing.T) {
@ -259,18 +258,16 @@ func TestBindToUnavailabePeer(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
defer func() {
log.Infof("closing client")
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}()
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn("bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
} }
func TestBindReconnect(t *testing.T) { func TestBindReconnect(t *testing.T) {
@ -315,7 +312,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
log.Infof("closing client") log.Infof("closing client Alice")
err = clientAlice.Close() err = clientAlice.Close()
if err != nil { if err != nil {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
@ -403,52 +400,6 @@ func TestCloseConn(t *testing.T) {
} }
} }
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout)
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func TestCloseRelayConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@ -491,3 +442,82 @@ func TestCloseRelayConn(t *testing.T) {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
} }
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}

View File

@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
) )
type Conn struct { type Conn struct {
@ -52,9 +51,5 @@ func (c *Conn) SetDeadline(t time.Time) error {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
if err != nil {
log.Errorf("failed to close conn?: %s", err)
}
return c.Conn.Close() return c.Conn.Close()
} }

View File

@ -3,13 +3,17 @@ package ws
import ( import (
"fmt" "fmt"
"net" "net"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
func Dial(address string) (net.Conn, error) { func Dial(address string) (net.Conn, error) {
addr := fmt.Sprintf("ws://" + address) addr := fmt.Sprintf("ws://" + address)
wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil) wsDialer := websocket.Dialer{
HandshakeTimeout: 3 * time.Second,
}
wsConn, _, err := wsDialer.Dial(addr, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -84,8 +84,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
} }
if !foreign { if !foreign {
log.Debugf("open connection to permanent server: %s", peerKey)
return m.relayClient.OpenConn(peerKey) return m.relayClient.OpenConn(peerKey)
} else { } else {
log.Debugf("open connection to foreign server: %s", serverAddress)
return m.openConnVia(serverAddress, peerKey) return m.openConnVia(serverAddress, peerKey)
} }
} }

View File

@ -47,12 +47,14 @@ func TestForeignConn(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
clientAlice := NewManager(ctx, addr1, idAlice) mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr1, idAlice)
clientAlice.Serve() clientAlice.Serve()
idBob := "bob" idBob := "bob"
log.Debugf("connect by bob") log.Debugf("connect by bob")
clientBob := NewManager(ctx, addr2, idBob) clientBob := NewManager(mCtx, addr2, idBob)
clientBob.Serve() clientBob.Serve()
bobsSrvAddr, err := clientBob.RelayAddress() bobsSrvAddr, err := clientBob.RelayAddress()
@ -132,61 +134,9 @@ func TestForeginConnClose(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
clientAlice := NewManager(ctx, addr1, idAlice) mCtx, cancel := context.WithCancel(ctx)
clientAlice.Serve() defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
conn, err := clientAlice.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
select {}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
mgr := NewManager(ctx, addr1, idAlice)
relayCleanupInterval = 2 * time.Second
mgr.Serve() mgr.Serve()
conn, err := mgr.OpenConn(addr2, "anotherpeer") conn, err := mgr.OpenConn(addr2, "anotherpeer")
@ -198,9 +148,124 @@ func TestForeginAutoClose(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to close connection: %s", err) t.Fatalf("failed to close connection: %s", err)
} }
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
t.Log("binding server 1.")
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
t.Log("binding server 2.")
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve()
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second) time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 { if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients)) t.Errorf("expected 0, got %d", len(mgr.relayClients))
} }
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr, "alice")
clientAlice.Serve()
ra, err := clientAlice.RelayAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
} }

View File

@ -10,6 +10,7 @@ const (
MsgTypeHello MsgType = 0 MsgTypeHello MsgType = 0
MsgTypeHelloResponse MsgType = 1 MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2 MsgTypeTransport MsgType = 2
MsgClose MsgType = 3
) )
var ( var (
@ -26,6 +27,8 @@ func (m MsgType) String() string {
return "hello response" return "hello response"
case MsgTypeTransport: case MsgTypeTransport:
return "transport" return "transport"
case MsgClose:
return "close"
default: default:
return "unknown" return "unknown"
} }
@ -39,6 +42,8 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
return msgType, nil return msgType, nil
case MsgTypeTransport: case MsgTypeTransport:
return msgType, nil return msgType, nil
case MsgClose:
return msgType, nil
default: default:
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg)) return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
} }
@ -52,6 +57,8 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
return msgType, nil return msgType, nil
case MsgTypeTransport: case MsgTypeTransport:
return msgType, nil return msgType, nil
case MsgClose:
return msgType, nil
default: default:
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg)) return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
} }
@ -81,6 +88,14 @@ func MarshalHelloResponse() []byte {
return msg return msg
} }
// Close message
func MarshalCloseMsg() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgClose)
return msg
}
// Transport message // Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) []byte { func MarshalTransportMsg(peerID []byte, payload []byte) []byte {

View File

@ -5,4 +5,5 @@ import "net"
type Listener interface { type Listener interface {
Listen(func(conn net.Conn)) error Listen(func(conn net.Conn)) error
Close() error Close() error
WaitForExitAcceptedConns()
} }

View File

@ -21,6 +21,11 @@ type Listener struct {
lock sync.Mutex lock sync.Mutex
} }
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
return
}
func NewListener(address string) listener.Listener { func NewListener(address string) listener.Listener {
return &Listener{ return &Listener{
address: address, address: address,
@ -61,11 +66,11 @@ func (l *Listener) Close() error {
l.lock.Lock() l.lock.Lock()
defer l.lock.Unlock() defer l.lock.Unlock()
log.Infof("closing UDP server")
if l.listener == nil { if l.listener == nil {
return nil return nil
} }
log.Infof("closing UDP listener")
close(l.quit) close(l.quit)
err := l.listener.Close() err := l.listener.Close()
l.wg.Wait() l.wg.Wait()

View File

@ -33,8 +33,11 @@ func NewListener(address string) listener.Listener {
} }
} }
// Listen todo: prevent multiple call
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
if l.server != nil {
return errors.New("server is already running")
}
l.acceptFn = acceptFn l.acceptFn = acceptFn
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", l.onAccept) mux.HandleFunc("/", l.onAccept)
@ -69,6 +72,10 @@ func (l *Listener) Close() error {
return nil return nil
} }
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
}
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) { func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
l.wg.Add(1) l.wg.Add(1)
defer l.wg.Done() defer l.wg.Done()

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -17,6 +18,8 @@ type Conn struct {
lAddr *net.TCPAddr lAddr *net.TCPAddr
rAddr *net.TCPAddr rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
ctx context.Context ctx context.Context
} }
@ -32,7 +35,7 @@ func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx) t, r, err := c.Reader(c.ctx)
if err != nil { if err != nil {
return 0, ioErrHandling(err) return 0, c.ioErrHandling(err)
} }
if t != websocket.MessageBinary { if t != websocket.MessageBinary {
@ -42,7 +45,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
n, err = r.Read(b) n, err = r.Read(b)
if err != nil { if err != nil {
return 0, ioErrHandling(err) return 0, c.ioErrHandling(err)
} }
return n, err return n, err
} }
@ -76,11 +79,23 @@ func (c *Conn) SetDeadline(t time.Time) error {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.Conn.Close(websocket.StatusNormalClosure, "") c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
return c.Conn.CloseNow()
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
} }
// todo: fix io.EOF handling
func ioErrHandling(err error) error {
var wErr *websocket.CloseError var wErr *websocket.CloseError
if !errors.As(err, &wErr) { if !errors.As(err, &wErr) {
return err return err

View File

@ -56,15 +56,18 @@ func (l *Listener) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
log.Debugf("closing WS server") log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil { if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err) return fmt.Errorf("server shutdown failed: %v", err)
} }
log.Infof("WS listener stopped")
l.wg.Wait()
return nil return nil
} }
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
l.wg.Add(1) l.wg.Add(1)
defer l.wg.Done() defer l.wg.Done()

View File

@ -15,11 +15,9 @@ import (
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr" ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
) )
// Server
// todo:
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
type Server struct { type Server struct {
store *Store store *Store
storeMu sync.RWMutex
UDPListener listener.Listener UDPListener listener.Listener
WSListener listener.Listener WSListener listener.Listener
@ -28,6 +26,7 @@ type Server struct {
func NewServer() *Server { func NewServer() *Server {
return &Server{ return &Server{
store: NewStore(), store: NewStore(),
storeMu: sync.RWMutex{},
} }
} }
@ -69,6 +68,11 @@ func (r *Server) Close() error {
if r.UDPListener != nil { if r.UDPListener != nil {
uErr = r.UDPListener.Close() uErr = r.UDPListener.Close()
} }
r.sendCloseMsgs()
r.WSListener.WaitForExitAcceptedConns()
err := errors.Join(wErr, uErr) err := errors.Join(wErr, uErr)
return err return err
} }
@ -88,7 +92,7 @@ func (r *Server) accept(conn net.Conn) {
r.store.AddPeer(peer) r.store.AddPeer(peer)
defer func() { defer func() {
r.store.DeletePeer(peer) r.store.DeletePeer(peer)
peer.Log.Infof("peer left") peer.Log.Infof("relay connection closed")
}() }()
for { for {
@ -132,10 +136,33 @@ func (r *Server) accept(conn net.Conn) {
} }
return return
}() }()
case messages.MsgClose:
peer.Log.Infof("peer disconnected gracefully")
_ = conn.Close()
return
} }
} }
} }
func (r *Server) sendCloseMsgs() {
msg := messages.MarshalCloseMsg()
r.storeMu.Lock()
log.Debugf("sending close messages to %d peers", len(r.store.peers))
for _, p := range r.store.peers {
_, err := p.conn.Write(msg)
if err != nil {
log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
log.Errorf("failed to close connection to peer: %s", err)
}
}
r.storeMu.Unlock()
}
func handShake(conn net.Conn) (*Peer, error) { func handShake(conn net.Conn) (*Peer, error) {
buf := make([]byte, 1500) buf := make([]byte, 1500)
n, err := conn.Read(buf) n, err := conn.Read(buf)

View File

@ -24,7 +24,6 @@ func (s *Store) AddPeer(peer *Peer) {
func (s *Store) DeletePeer(peer *Peer) { func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
delete(s.peers, peer.String()) delete(s.peers, peer.String())
} }
@ -35,3 +34,14 @@ func (s *Store) Peer(id string) (*Peer, bool) {
p, ok := s.peers[id] p, ok := s.peers[id]
return p, ok return p, ok
} }
func (s *Store) Peers() []*Peer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}