Fix in client the close event

This commit is contained in:
Zoltán Papp 2024-05-26 22:14:33 +02:00
parent 36b2cd16cc
commit 173ca25dac
12 changed files with 257 additions and 54 deletions

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -29,20 +30,25 @@ type connContainer struct {
type Client struct { type Client struct {
log *log.Entry log *log.Entry
ctx context.Context
ctxCancel context.CancelFunc
serverAddress string serverAddress string
hashedID []byte hashedID []byte
conns map[string]*connContainer conns map[string]*connContainer // todo handle it in thread safe way
relayConn net.Conn relayConn net.Conn
relayConnState bool relayConnState bool
mu sync.Mutex mu sync.Mutex
} }
func NewClient(serverAddress, peerID string) *Client { func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
ctx, ctxCancel := context.WithCancel(ctx)
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ return &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithField("client_id", hashedStringId),
ctx: ctx,
ctxCancel: ctxCancel,
serverAddress: serverAddress, serverAddress: serverAddress,
hashedID: hashedID, hashedID: hashedID,
conns: make(map[string]*connContainer), conns: make(map[string]*connContainer),
@ -51,7 +57,11 @@ func NewClient(serverAddress, peerID string) *Client {
func (c *Client) Connect() error { func (c *Client) Connect() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() if c.relayConnState {
c.mu.Unlock()
return nil
}
conn, err := udp.Dial(c.serverAddress) conn, err := udp.Dial(c.serverAddress)
if err != nil { if err != nil {
return err return err
@ -68,18 +78,39 @@ func (c *Client) Connect() error {
return err return err
} }
err = c.relayConn.SetReadDeadline(time.Time{})
if err != nil {
log.Errorf("failed to reset read deadline: %s", err)
return err
}
c.relayConnState = true c.relayConnState = true
go c.readLoop() c.mu.Unlock()
go func() {
<-c.ctx.Done()
cErr := c.close()
if cErr != nil {
log.Errorf("failed to close relay connection: %s", cErr)
}
}()
// blocking function
c.readLoop()
c.mu.Lock()
// close all Conn types
for _, container := range c.conns {
close(container.messages)
}
c.conns = make(map[string]*connContainer)
c.mu.Unlock()
return nil return nil
} }
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.relayConnState {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID) hashedID, hashedStringID := messages.HashID(dstPeerID)
log.Infof("open connection to peer: %s", hashedStringID) log.Infof("open connection to peer: %s", hashedStringID)
messageBuffer := make(chan Msg, 2) messageBuffer := make(chan Msg, 2)
@ -93,6 +124,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.ctxCancel()
return c.close()
}
func (c *Client) close() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -101,11 +137,20 @@ func (c *Client) Close() error {
} }
c.relayConnState = false c.relayConnState = false
err := c.relayConn.Close() err := c.relayConn.Close()
return err return err
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
defer func() {
err := c.relayConn.SetReadDeadline(time.Time{})
if err != nil {
log.Errorf("failed to reset read deadline: %s", err)
}
}()
msg, err := messages.MarshalHelloMsg(c.hashedID) msg, err := messages.MarshalHelloMsg(c.hashedID)
if err != nil { if err != nil {
log.Errorf("failed to marshal hello message: %s", err) log.Errorf("failed to marshal hello message: %s", err)
@ -145,7 +190,7 @@ func (c *Client) handShake() error {
func (c *Client) readLoop() { func (c *Client) readLoop() {
defer func() { defer func() {
c.log.Debugf("exit from read loop") c.log.Tracef("exit from read loop")
}() }()
var errExit error var errExit error
var n int var n int

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
) )
type Conn struct { type Conn struct {
@ -51,6 +52,9 @@ func (c *Conn) SetDeadline(t time.Time) error {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
if err != nil {
log.Errorf("failed to close conn?: %s", err)
}
return c.Conn.Close() return c.Conn.Close()
} }

View File

@ -2,42 +2,74 @@ package client
import ( import (
"context" "context"
"net"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus"
) )
type Manager struct { type Manager struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc
srvAddress string srvAddress string
peerID string peerID string
wg sync.WaitGroup reconnectTime time.Duration
clients map[string]*Client mu sync.Mutex
clientsMutex sync.RWMutex client *Client
} }
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
ctx, cancel := context.WithCancel(ctx)
return &Manager{ return &Manager{
ctx: ctx, ctx: ctx,
ctxCancel: cancel, srvAddress: serverAddress,
srvAddress: serverAddress, peerID: peerID,
peerID: peerID, reconnectTime: 5 * time.Second,
clients: make(map[string]*Client),
} }
} }
func (m *Manager) Teardown() { func (m *Manager) Serve() {
m.ctxCancel() ok := m.mu.TryLock()
m.wg.Wait() if !ok {
}
func (m *Manager) newSrvConnection(address string) {
if _, ok := m.clients[address]; ok {
return return
} }
// client := NewClient(address, m.peerID) m.client = NewClient(m.ctx, m.srvAddress, m.peerID)
//err = client.Connect()
go func() {
defer m.mu.Unlock()
// todo this is not thread safe
for {
select {
case <-m.ctx.Done():
return
default:
m.connect()
}
select {
case <-m.ctx.Done():
return
case <-time.After(2 * time.Second): //timeout
}
}
}()
}
func (m *Manager) OpenConn(peerKey string) (net.Conn, error) {
// todo m.client nil check
return m.client.OpenConn(peerKey)
}
// connect is blocking
func (m *Manager) connect() {
err := m.client.Connect()
if err != nil {
if m.ctx.Err() != nil {
return
}
log.Errorf("connection error with '%s': %s", m.srvAddress, err)
}
} }

View File

@ -23,4 +23,6 @@ func main() {
log.Errorf("failed to bind server: %s", err) log.Errorf("failed to bind server: %s", err)
os.Exit(1) os.Exit(1)
} }
select {}
} }

View File

@ -61,6 +61,7 @@ func (l *Listener) Close() error {
l.lock.Lock() l.lock.Lock()
defer l.lock.Unlock() defer l.lock.Unlock()
log.Infof("closing UDP server")
if l.listener == nil { if l.listener == nil {
return nil return nil
} }
@ -95,6 +96,7 @@ func (l *Listener) readLoop() {
} }
pConn = NewConn(l.listener, addr) pConn = NewConn(l.listener, addr)
log.Infof("new connection from: %s", pConn.RemoteAddr())
l.conns[addr.String()] = pConn l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn) go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n]) pConn.onNewMsg(buf[:n])

View File

@ -1,7 +1,9 @@
package ws package ws
import ( import (
"errors"
"fmt" "fmt"
"io"
"sync" "sync"
"time" "time"
@ -24,7 +26,7 @@ func NewConn(wsConn *websocket.Conn) *Conn {
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.NextReader() t, r, err := c.NextReader()
if err != nil { if err != nil {
return 0, err return 0, ioErrHandling(err)
} }
if t != websocket.BinaryMessage { if t != websocket.BinaryMessage {
@ -32,7 +34,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return 0, fmt.Errorf("unexpected message type") return 0, fmt.Errorf("unexpected message type")
} }
return r.Read(b) n, err = r.Read(b)
if err != nil {
return 0, ioErrHandling(err)
}
return n, err
} }
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(b []byte) (int, error) {
@ -55,3 +61,14 @@ func (c *Conn) SetDeadline(t time.Time) error {
} }
return nil return nil
} }
func ioErrHandling(err error) error {
var wErr *websocket.CloseError
if !errors.As(err, &wErr) {
return err
}
if wErr.Code == websocket.CloseNormalClosure {
return io.EOF
}
return err
}

View File

@ -42,7 +42,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
Addr: l.address, Addr: l.address,
} }
log.Debugf("WS server is listening on address: %s", l.address) log.Infof("WS server is listening on address: %s", l.address)
err := l.server.ListenAndServe() err := l.server.ListenAndServe()
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
return nil return nil
@ -77,6 +77,7 @@ func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
return return
} }
conn := NewConn(wsConn) conn := NewConn(wsConn)
log.Infof("new connection from: %s", conn.RemoteAddr())
l.acceptFn(conn) l.acceptFn(conn)
return return
} }

View File

@ -16,7 +16,6 @@ type Peer struct {
} }
func NewPeer(id []byte, conn net.Conn) *Peer { func NewPeer(id []byte, conn net.Conn) *Peer {
log.Debugf("new peer: %v", id)
stringID := messages.HashIDToString(id) stringID := messages.HashIDToString(id)
return &Peer{ return &Peer{
Log: log.WithField("peer_id", stringID), Log: log.WithField("peer_id", stringID),

View File

@ -1,15 +1,18 @@
package server package server
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp" "github.com/netbirdio/netbird/relay/server/listener/udp"
"github.com/netbirdio/netbird/relay/server/listener/ws"
) )
// Server // Server
@ -19,7 +22,8 @@ import (
type Server struct { type Server struct {
store *Store store *Store
listener listener.Listener UDPListener listener.Listener
WSListener listener.Listener
} }
func NewServer() *Server { func NewServer() *Server {
@ -29,15 +33,45 @@ func NewServer() *Server {
} }
func (r *Server) Listen(address string) error { func (r *Server) Listen(address string) error {
r.listener = udp.NewListener(address) wg := sync.WaitGroup{}
return r.listener.Listen(r.accept) wg.Add(2)
r.WSListener = ws.NewListener(address)
var wslErr error
go func() {
defer wg.Done()
wslErr = r.WSListener.Listen(r.accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
}()
r.UDPListener = udp.NewListener(address)
var udpLErr error
go func() {
defer wg.Done()
udpLErr = r.UDPListener.Listen(r.accept)
if udpLErr != nil {
log.Errorf("failed to bind ws server: %s", udpLErr)
}
}()
err := errors.Join(wslErr, udpLErr)
return err
} }
func (r *Server) Close() error { func (r *Server) Close() error {
if r.listener == nil { var wErr error
return nil if r.WSListener != nil {
wErr = r.WSListener.Close()
} }
return r.listener.Close()
var uErr error
if r.UDPListener != nil {
uErr = r.UDPListener.Close()
}
err := errors.Join(wErr, uErr)
return err
} }
func (r *Server) accept(conn net.Conn) { func (r *Server) accept(conn net.Conn) {
@ -50,12 +84,12 @@ func (r *Server) accept(conn net.Conn) {
} }
return return
} }
peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr()) peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer) r.store.AddPeer(peer)
defer func() { defer func() {
peer.Log.Debugf("teardown connection")
r.store.DeletePeer(peer) r.store.DeletePeer(peer)
peer.Log.Infof("peer left")
}() }()
for { for {

View File

@ -1,6 +1,7 @@
package test package test
import ( import (
"context"
"net" "net"
"os" "os"
"testing" "testing"
@ -20,6 +21,8 @@ func TestMain(m *testing.M) {
} }
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234" addr := "localhost:1234"
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@ -36,21 +39,21 @@ func TestClient(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, "alice") clientAlice := client.NewClient(ctx, addr, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
clientPlaceHolder := client.NewClient(addr, "clientPlaceHolder") clientPlaceHolder := client.NewClient(ctx, addr, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
clientBob := client.NewClient(addr, "bob") clientBob := client.NewClient(ctx, addr, "bob")
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -87,6 +90,7 @@ func TestClient(t *testing.T) {
} }
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234" addr := "localhost:1234"
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@ -103,7 +107,7 @@ func TestRegistration(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, "alice") clientAlice := client.NewClient(ctx, addr, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -117,6 +121,7 @@ func TestRegistration(t *testing.T) {
} }
func TestRegistrationTimeout(t *testing.T) { func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234, Port: 1234,
IP: net.ParseIP("0.0.0.0"), IP: net.ParseIP("0.0.0.0"),
@ -135,7 +140,7 @@ func TestRegistrationTimeout(t *testing.T) {
} }
defer tcpListener.Close() defer tcpListener.Close()
clientAlice := client.NewClient("127.0.0.1:1234", "alice") clientAlice := client.NewClient(ctx, "127.0.0.1:1234", "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -149,6 +154,7 @@ func TestRegistrationTimeout(t *testing.T) {
} }
func TestEcho(t *testing.T) { func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
addr := "localhost:1234" addr := "localhost:1234"
@ -167,7 +173,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, idAlice) clientAlice := client.NewClient(ctx, addr, idAlice)
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -179,7 +185,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := client.NewClient(addr, idBob) clientBob := client.NewClient(ctx, addr, idBob)
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -229,6 +235,8 @@ func TestEcho(t *testing.T) {
} }
func TestBindToUnavailabePeer(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234" addr := "localhost:1234"
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@ -246,7 +254,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, "alice") clientAlice := client.NewClient(ctx, addr, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -266,6 +274,8 @@ func TestBindToUnavailabePeer(t *testing.T) {
} }
func TestBindReconnect(t *testing.T) { func TestBindReconnect(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234" addr := "localhost:1234"
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@ -283,7 +293,7 @@ func TestBindReconnect(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, "alice") clientAlice := client.NewClient(ctx, addr, "alice")
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -294,7 +304,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
clientBob := client.NewClient(addr, "bob") clientBob := client.NewClient(ctx, addr, "bob")
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
@ -311,7 +321,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = client.NewClient(addr, "alice") clientAlice = client.NewClient(ctx, addr, "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)

View File

@ -0,0 +1,57 @@
package test
import (
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/relay/server"
)
func TestManager(t *testing.T) {
addr := "localhost:1239"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cm := client.NewManager(ctx, addr, "me")
cm.Serve()
// wait for the relay handshake to complete
time.Sleep(1 * time.Second)
conn, err := cm.OpenConn("remotepeer")
if err != nil {
t.Errorf("failed to open connection: %s", err)
}
readCtx, readCancel := context.WithCancel(context.Background())
defer readCancel()
go func() {
_, _ = conn.Read(make([]byte, 1))
readCancel()
}()
cancel()
select {
case <-time.After(2 * time.Second):
t.Errorf("client peer conn did not close automatically")
case <-readCtx.Done():
// conn exited well
}
}