From 4d0e16f2d0cf8af17bfb8c71a138c10c7c5c95c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Zolt=C3=A1n=20Papp?=
Date: Thu, 27 Jun 2024 02:36:44 +0200
Subject: [PATCH] - Remove WaitForExitAcceptedConns logic from server
 - Implement thread safe gracefully close logic
 - organise the server code

---
Review notes (text in this section is ignored by git am):
- Relay.Close now sets the closed flag, so the guard in Relay.Accept takes
  effect; Accept closes connections that arrive after Close.
- Peer.writeWithTimeout returns the byte count reported by conn.Write.
- Peer.CloseGracefully defers the unlock right after taking the lock.
- Write failures are logged together with the error; the UDP bind failure is
  no longer logged as a ws failure; "closeing" typo fixed.
- Listener fields renamed to udpListener / wsListener (Go initialisms).
- The post-image blob ids on the index lines of peer.go, relay.go and
  server.go were not recomputed after these changes.

 relay/server/listener/listener.go             |   1 -
 relay/server/listener/udp/listener.go         |   5 -
 .../server/listener/{wsnhooyr => ws}/conn.go  |   0
 .../listener/{wsnhooyr => ws}/listener.go     |  10 --
 relay/server/peer.go                          | 138 ++++++++++++++-
 relay/server/relay.go                         | 113 ++++++++++++
 relay/server/server.go                        | 164 +++---------------
 7 files changed, 264 insertions(+), 167 deletions(-)
 rename relay/server/listener/{wsnhooyr => ws}/conn.go (100%)
 rename relay/server/listener/{wsnhooyr => ws}/listener.go (92%)
 create mode 100644 relay/server/relay.go

diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go
index 3336e0ad8..66e6d357e 100644
--- a/relay/server/listener/listener.go
+++ b/relay/server/listener/listener.go
@@ -5,5 +5,4 @@ import "net"
 type Listener interface {
 	Listen(func(conn net.Conn)) error
 	Close() error
-	WaitForExitAcceptedConns()
 }
diff --git a/relay/server/listener/udp/listener.go b/relay/server/listener/udp/listener.go
index 3c2dfc070..31b6098e5 100644
--- a/relay/server/listener/udp/listener.go
+++ b/relay/server/listener/udp/listener.go
@@ -21,11 +21,6 @@ type Listener struct {
 	lock sync.Mutex
 }
 
-func (l *Listener) WaitForExitAcceptedConns() {
-	l.wg.Wait()
-	return
-}
-
 func NewListener(address string) listener.Listener {
 	return &Listener{
 		address: address,
diff --git a/relay/server/listener/wsnhooyr/conn.go b/relay/server/listener/ws/conn.go
similarity index 100%
rename from relay/server/listener/wsnhooyr/conn.go
rename to relay/server/listener/ws/conn.go
diff --git a/relay/server/listener/wsnhooyr/listener.go b/relay/server/listener/ws/listener.go
similarity index 92%
rename from relay/server/listener/wsnhooyr/listener.go
rename to relay/server/listener/ws/listener.go
index b6bcd12c0..26ba2276e 100644
--- a/relay/server/listener/wsnhooyr/listener.go
+++ b/relay/server/listener/ws/listener.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"sync"
 	"time"
 
 	log "github.com/sirupsen/logrus"
@@ -18,7 +17,6 @@ import (
 
 type Listener struct {
 	address  string
-	wg       sync.WaitGroup
 	server   *http.Server
 	acceptFn func(conn net.Conn)
 }
@@ -63,14 +61,7 @@ func (l *Listener) Close() error {
 	return nil
 }
 
-func (l *Listener) WaitForExitAcceptedConns() {
-	l.wg.Wait()
-}
-
 func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
-	l.wg.Add(1)
-	defer l.wg.Done()
-
 	wsConn, err := websocket.Accept(w, r, nil)
 	if err != nil {
 		log.Errorf("failed to accept ws connection: %s", err)
@@ -91,5 +82,4 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
 
 	conn := NewConn(wsConn, lAddr, rAddr)
 	l.acceptFn(conn)
-	return
 }
diff --git a/relay/server/peer.go b/relay/server/peer.go
index 7af113079..14a86a8ab 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -1,33 +1,149 @@
 package server
 
 import (
+	"context"
+	"fmt"
+	"io"
 	"net"
+	"sync"
+	"time"
 
 	log "github.com/sirupsen/logrus"
 
 	"github.com/netbirdio/netbird/relay/messages"
 )
 
+const (
+	bufferSize = 8820
+)
+
+// Peer is a connected relay client together with the store it uses to look up
+// the destination peers of its transport messages.
 type Peer struct {
-	Log  *log.Entry
-	idS  string
-	idB  []byte
-	conn net.Conn
+	log    *log.Entry
+	idS    string
+	idB    []byte
+	conn   net.Conn
+	connMu sync.RWMutex
+	store  *Store
 }
 
-func NewPeer(id []byte, conn net.Conn) *Peer {
+// NewPeer creates a Peer for the given hashed ID, connection and peer store.
+func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
 	stringID := messages.HashIDToString(id)
 	return &Peer{
-		Log:  log.WithField("peer_id", stringID),
-		idB:  id,
-		idS:  stringID,
-		conn: conn,
+		log:   log.WithField("peer_id", stringID),
+		idS:   stringID,
+		idB:   id,
+		conn:  conn,
+		store: store,
 	}
 }
-func (p *Peer) ID() []byte {
-	return p.idB
+
+// Work reads messages from the peer connection until a read fails, a message
+// type cannot be determined or the peer sends a close message. Transport
+// messages are forwarded to the destination peer found in the store.
+func (p *Peer) Work() {
+	buf := make([]byte, bufferSize)
+	for {
+		n, err := p.conn.Read(buf)
+		if err != nil {
+			if err != io.EOF {
+				p.log.Errorf("failed to read message: %s", err)
+			}
+			return
+		}
+
+		msg := buf[:n]
+
+		msgType, err := messages.DetermineClientMsgType(msg)
+		if err != nil {
+			p.log.Errorf("failed to determine message type: %s", err)
+			return
+		}
+		switch msgType {
+		case messages.MsgHealthCheck:
+		case messages.MsgTypeTransport:
+			peerID, err := messages.UnmarshalTransportID(msg)
+			if err != nil {
+				p.log.Errorf("failed to unmarshal transport message: %s", err)
+				continue
+			}
+			stringPeerID := messages.HashIDToString(peerID)
+			dp, ok := p.store.Peer(stringPeerID)
+			if !ok {
+				p.log.Errorf("peer not found: %s", stringPeerID)
+				continue
+			}
+			err = messages.UpdateTransportMsg(msg, p.idB)
+			if err != nil {
+				p.log.Errorf("failed to update transport message: %s", err)
+				continue
+			}
+			_, err = dp.Write(msg)
+			if err != nil {
+				p.log.Errorf("failed to write transport message to %s: %s", dp.String(), err)
+			}
+		case messages.MsgClose:
+			p.log.Infof("peer exited gracefully")
+			_ = p.conn.Close()
+			return
+		}
+	}
+}
+
+// Write writes data to the connection
+// it has been called by the remote peer
+func (p *Peer) Write(b []byte) (int, error) {
+	p.connMu.RLock()
+	defer p.connMu.RUnlock()
+	return p.conn.Write(b)
+}
+
+// CloseGracefully sends a close message to the peer and closes the connection.
+// It holds the write lock, so it does not run concurrently with Write.
+func (p *Peer) CloseGracefully(ctx context.Context) {
+	p.connMu.Lock()
+	defer p.connMu.Unlock()
+
+	_, err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
+	if err != nil {
+		log.Errorf("failed to send close message to peer %s: %s", p.String(), err)
+	}
+
+	err = p.conn.Close()
+	if err != nil {
+		log.Errorf("failed to close connection to peer: %s", err)
+	}
 }
 
+// String returns the peer ID in its string form.
 func (p *Peer) String() string {
 	return p.idS
 }
+
+// writeWithTimeout writes buf to the connection and stops waiting when ctx is
+// done or after 3 seconds. The write is not interrupted on timeout; its
+// goroutine ends when Write returns, e.g. once the connection is closed.
+func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
+	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+	defer cancel()
+
+	writeDone := make(chan struct{})
+	var (
+		n   int
+		err error
+	)
+
+	go func() {
+		n, err = p.conn.Write(buf)
+		close(writeDone)
+	}()
+
+	select {
+	case <-ctx.Done():
+		return 0, fmt.Errorf("write operation timed out")
+	case <-writeDone:
+		return n, err
+	}
+}
diff --git a/relay/server/relay.go b/relay/server/relay.go
new file mode 100644
index 000000000..39719f4a9
--- /dev/null
+++ b/relay/server/relay.go
@@ -0,0 +1,113 @@
+package server
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync"
+
+	log "github.com/sirupsen/logrus"
+
+	"github.com/netbirdio/netbird/relay/messages"
+)
+
+// Relay accepts peer connections and keeps the connected peers in its store.
+type Relay struct {
+	store *Store
+
+	closed  bool
+	closeMu sync.RWMutex
+}
+
+// NewRelay creates a Relay with a new peer store.
+func NewRelay() *Relay {
+	return &Relay{
+		store: NewStore(),
+	}
+}
+
+// Accept performs the handshake on a new connection, registers the peer and
+// serves it in its own goroutine. Connections that arrive after Close has
+// been called are closed immediately.
+func (r *Relay) Accept(conn net.Conn) {
+	r.closeMu.RLock()
+	defer r.closeMu.RUnlock()
+	if r.closed {
+		// the relay will never serve this connection, do not leak it
+		_ = conn.Close()
+		return
+	}
+
+	peerID, err := handShake(conn)
+	if err != nil {
+		log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
+		cErr := conn.Close()
+		if cErr != nil {
+			log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
+		}
+		return
+	}
+
+	peer := NewPeer(peerID, conn, r.store)
+	peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
+	r.store.AddPeer(peer)
+
+	go func() {
+		peer.Work()
+		r.store.DeletePeer(peer)
+		peer.log.Debugf("relay connection closed")
+	}()
+}
+
+// Close marks the relay as closed and gracefully closes every peer in the
+// store in parallel. It returns once all of them have been closed.
+func (r *Relay) Close(ctx context.Context) {
+	log.Infof("closing connection with all peers")
+	r.closeMu.Lock()
+	defer r.closeMu.Unlock()
+	// without this flag Accept would keep registering peers after shutdown
+	r.closed = true
+
+	wg := sync.WaitGroup{}
+	peers := r.store.Peers()
+	for _, peer := range peers {
+		wg.Add(1)
+		go func(p *Peer) {
+			p.CloseGracefully(ctx)
+			wg.Done()
+		}(peer)
+	}
+	wg.Wait()
+}
+
+// handShake reads the hello message from a new connection, answers it with a
+// hello response and returns the peer ID carried by the hello message.
+func handShake(conn net.Conn) ([]byte, error) {
+	buf := make([]byte, messages.MaxHandshakeSize)
+	n, err := conn.Read(buf)
+	if err != nil {
+		log.Errorf("failed to read message: %s", err)
+		return nil, err
+	}
+	msgType, err := messages.DetermineClientMsgType(buf[:n])
+	if err != nil {
+		return nil, err
+	}
+	if msgType != messages.MsgTypeHello {
+		tErr := fmt.Errorf("invalid message type")
+		log.Errorf("failed to handshake: %s", tErr)
+		return nil, tErr
+	}
+	peerID, err := messages.UnmarshalHelloMsg(buf[:n])
+	if err != nil {
+		log.Errorf("failed to handshake: %s", err)
+		return nil, err
+	}
+
+	msg := messages.MarshalHelloResponse()
+	_, err = conn.Write(msg)
+	if err != nil {
+		return nil, err
+	}
+	return peerID, nil
+}
diff --git a/relay/server/server.go b/relay/server/server.go
index f52f7eab8..cf48e19e3 100644
--- a/relay/server/server.go
+++ b/relay/server/server.go
@@ -1,36 +1,29 @@
 package server
 
 import (
+	"context"
 	"errors"
-	"fmt"
-	"io"
-	"net"
 	"sync"
+	"time"
 
 	log "github.com/sirupsen/logrus"
 
-	"github.com/netbirdio/netbird/relay/messages"
 	"github.com/netbirdio/netbird/relay/server/listener"
 	"github.com/netbirdio/netbird/relay/server/listener/udp"
-	ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
-)
-
-const (
-	bufferSize = 8820
+	"github.com/netbirdio/netbird/relay/server/listener/ws"
 )
 
+// Server exposes the relay over a WebSocket listener and a UDP listener.
 type Server struct {
-	store   *Store
-	storeMu sync.RWMutex
-
-	UDPListener listener.Listener
-	WSListener  listener.Listener
+	relay       *Relay
+	udpListener listener.Listener
+	wsListener  listener.Listener
 }
 
+// NewServer creates a Server backed by a new Relay.
 func NewServer() *Server {
 	return &Server{
-		store:   NewStore(),
-		storeMu: sync.RWMutex{},
+		relay: NewRelay(),
 	}
 }
 
@@ -38,21 +31,21 @@ func (r *Server) Listen(address string) error {
 	wg := sync.WaitGroup{}
 	wg.Add(2)
 
-	r.WSListener = ws.NewListener(address)
+	r.wsListener = ws.NewListener(address)
 	var wslErr error
 	go func() {
 		defer wg.Done()
-		wslErr = r.WSListener.Listen(r.accept)
+		wslErr = r.wsListener.Listen(r.relay.Accept)
 		if wslErr != nil {
 			log.Errorf("failed to bind ws server: %s", wslErr)
 		}
 	}()
 
-	r.UDPListener = udp.NewListener(address)
+	r.udpListener = udp.NewListener(address)
 	var udpLErr error
 	go func() {
 		defer wg.Done()
-		udpLErr = r.UDPListener.Listen(r.accept)
+		udpLErr = r.udpListener.Listen(r.relay.Accept)
 		if udpLErr != nil {
-			log.Errorf("failed to bind ws server: %s", udpLErr)
+			log.Errorf("failed to bind udp server: %s", udpLErr)
 		}
@@ -64,131 +57,22 @@ func (r *Server) Listen(address string) error {
 
+// Close stops both listeners and then gracefully closes the accepted connections.
 func (r *Server) Close() error {
 	var wErr error
-	if r.WSListener != nil {
-		wErr = r.WSListener.Close()
+	// stop serving new connections
+	if r.wsListener != nil {
+		wErr = r.wsListener.Close()
 	}
 
 	var uErr error
-	if r.UDPListener != nil {
-		uErr = r.UDPListener.Close()
+	if r.udpListener != nil {
+		uErr = r.udpListener.Close()
 	}
 
-	r.sendCloseMsgs()
-
-	r.WSListener.WaitForExitAcceptedConns()
+	// close accepted connections gracefully
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	r.relay.Close(ctx)
 
 	err := errors.Join(wErr, uErr)
 	return err
 }
-
-func (r *Server) accept(conn net.Conn) {
-	peer, err := handShake(conn)
-	if err != nil {
-		log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
-		cErr := conn.Close()
-		if cErr != nil {
-			log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
-		}
-		return
-	}
-	peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
-
-	r.store.AddPeer(peer)
-	defer func() {
-		r.store.DeletePeer(peer)
-		peer.Log.Infof("relay connection closed")
-	}()
-
-	buf := make([]byte, bufferSize)
-	for {
-		n, err := conn.Read(buf)
-		if err != nil {
-			if err != io.EOF {
-				peer.Log.Errorf("failed to read message: %s", err)
-			}
-			return
-		}
-
-		msg := buf[:n]
-
-		msgType, err := messages.DetermineClientMsgType(msg)
-		if err != nil {
-			peer.Log.Errorf("failed to determine message type: %s", err)
-			return
-		}
-		switch msgType {
-		case messages.MsgTypeTransport:
-			peerID, err := messages.UnmarshalTransportID(msg)
-			if err != nil {
-				peer.Log.Errorf("failed to unmarshal transport message: %s", err)
-				continue
-			}
-			stringPeerID := messages.HashIDToString(peerID)
-			dp, ok := r.store.Peer(stringPeerID)
-			if !ok {
-				peer.Log.Errorf("peer not found: %s", stringPeerID)
-				continue
-			}
-			err = messages.UpdateTransportMsg(msg, peer.ID())
-			if err != nil {
-				peer.Log.Errorf("failed to update transport message: %s", err)
-				continue
-			}
-			_, err = dp.conn.Write(msg)
-			if err != nil {
-				peer.Log.Errorf("failed to write transport message to: %s", dp.String())
-			}
-		case messages.MsgClose:
-			peer.Log.Infof("peer disconnected gracefully")
-			_ = conn.Close()
-			return
-		}
-	}
-}
-
-func (r *Server) sendCloseMsgs() {
-	msg := messages.MarshalCloseMsg()
-
-	r.storeMu.Lock()
-	log.Debugf("sending close messages to %d peers", len(r.store.peers))
-	for _, p := range r.store.peers {
-		_, err := p.conn.Write(msg)
-		if err != nil {
-			log.Errorf("failed to send close message to peer: %s", p.String())
-		}
-
-		err = p.conn.Close()
-		if err != nil {
-			log.Errorf("failed to close connection to peer: %s", err)
-		}
-	}
-	r.storeMu.Unlock()
-}
-
-func handShake(conn net.Conn) (*Peer, error) {
-	buf := make([]byte, messages.MaxHandshakeSize)
-	n, err := conn.Read(buf)
-	if err != nil {
-		log.Errorf("failed to read message: %s", err)
-		return nil, err
-	}
-	msgType, err := messages.DetermineClientMsgType(buf[:n])
-	if err != nil {
-		return nil, err
-	}
-	if msgType != messages.MsgTypeHello {
-		tErr := fmt.Errorf("invalid message type")
-		log.Errorf("failed to handshake: %s", tErr)
-		return nil, tErr
-	}
-	peerId, err := messages.UnmarshalHelloMsg(buf[:n])
-	if err != nil {
-		log.Errorf("failed to handshake: %s", err)
-		return nil, err
-	}
-	p := NewPeer(peerId, conn)
-
-	msg := messages.MarshalHelloResponse()
-	_, err = conn.Write(msg)
-	return p, err
-}