- Remove WaitForExitAcceptedConns logic from server

- Implement thread-safe graceful close logic
- Organise the server code
This commit is contained in:
Zoltán Papp 2024-06-27 02:36:44 +02:00
parent 3fcdb51376
commit 4d0e16f2d0
7 changed files with 234 additions and 166 deletions

View File

@ -5,5 +5,4 @@ import "net"
type Listener interface {
Listen(func(conn net.Conn)) error
Close() error
WaitForExitAcceptedConns()
}

View File

@ -21,11 +21,6 @@ type Listener struct {
lock sync.Mutex
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
return
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,

View File

@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
@ -18,7 +17,6 @@ import (
type Listener struct {
address string
wg sync.WaitGroup
server *http.Server
acceptFn func(conn net.Conn)
}
@ -63,14 +61,7 @@ func (l *Listener) Close() error {
return nil
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
l.wg.Add(1)
defer l.wg.Done()
wsConn, err := websocket.Accept(w, r, nil)
if err != nil {
log.Errorf("failed to accept ws connection: %s", err)
@ -91,5 +82,4 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
return
}

View File

@ -1,33 +1,137 @@
package server
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
)
type Peer struct {
Log *log.Entry
log *log.Entry
idS string
idB []byte
conn net.Conn
connMu sync.RWMutex
store *Store
}
func NewPeer(id []byte, conn net.Conn) *Peer {
func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
stringID := messages.HashIDToString(id)
return &Peer{
Log: log.WithField("peer_id", stringID),
idB: id,
log: log.WithField("peer_id", stringID),
idS: stringID,
idB: id,
conn: conn,
store: store,
}
}
func (p *Peer) ID() []byte {
return p.idB
// Work runs the peer's read loop. It reads messages from the peer's
// connection and dispatches them by type until one of the following happens:
//   - the read fails (io.EOF is treated as a normal disconnect and not logged),
//   - the message type cannot be determined,
//   - the peer sends a close message, in which case the connection is closed.
//
// Health-check messages and message types without a case are ignored.
// Work does not remove the peer from the store; that is left to the caller.
func (p *Peer) Work() {
	buf := make([]byte, bufferSize)
	for {
		n, err := p.conn.Read(buf)
		if err != nil {
			if err != io.EOF {
				p.log.Errorf("failed to read message: %s", err)
			}
			return
		}

		msg := buf[:n]
		msgType, err := messages.DetermineClientMsgType(msg)
		if err != nil {
			p.log.Errorf("failed to determine message type: %s", err)
			return
		}

		switch msgType {
		case messages.MsgHealthCheck:
			// Nothing to do for health checks.
		case messages.MsgTypeTransport:
			p.forwardTransportMsg(msg)
		case messages.MsgClose:
			p.log.Infof("peer exited gracefully")
			_ = p.conn.Close() // best effort: the peer is leaving anyway
			return
		}
	}
}

// forwardTransportMsg looks up the destination peer named in msg, rewrites the
// message with this peer's ID via messages.UpdateTransportMsg and writes it to
// the destination. Every failure is logged and otherwise ignored so that the
// read loop keeps running.
func (p *Peer) forwardTransportMsg(msg []byte) {
	peerID, err := messages.UnmarshalTransportID(msg)
	if err != nil {
		p.log.Errorf("failed to unmarshal transport message: %s", err)
		return
	}

	stringPeerID := messages.HashIDToString(peerID)
	dp, ok := p.store.Peer(stringPeerID)
	if !ok {
		p.log.Errorf("peer not found: %s", stringPeerID)
		return
	}

	if err := messages.UpdateTransportMsg(msg, p.idB); err != nil {
		p.log.Errorf("failed to update transport message: %s", err)
		return
	}

	// msg aliases the read buffer; the write below completes before the next
	// Read reuses it.
	if _, err := dp.Write(msg); err != nil {
		// Include the cause: the original message only named the destination.
		p.log.Errorf("failed to write transport message to %s: %s", dp.String(), err)
	}
}
// Write sends b to this peer over its connection and reports the number of
// bytes written together with any write error. It is invoked on behalf of
// remote peers that relay traffic to this one.
//
// The connection mutex is taken in read mode for the duration of the write;
// CloseGracefully takes the same mutex in write mode.
func (p *Peer) Write(b []byte) (int, error) {
	p.connMu.RLock()
	defer p.connMu.RUnlock()

	written, err := p.conn.Write(b)
	return written, err
}
// CloseGracefully sends a close message to the peer and then closes its
// connection. Both steps are best effort: failures are logged on the peer's
// logger and the close is attempted even if the close message could not be
// sent.
//
// ctx bounds how long sending the close message may take (writeWithTimeout
// additionally caps it). The connection mutex is held exclusively for the
// whole call, so calls to Write wait until the shutdown has finished.
func (p *Peer) CloseGracefully(ctx context.Context) {
	p.connMu.Lock()
	// Deferred immediately after Lock so the mutex is released on every path.
	defer p.connMu.Unlock()

	// writeWithTimeout uses the connection directly; going through p.Write
	// here would try to re-acquire connMu, which is already held.
	if _, err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()); err != nil {
		p.log.Errorf("failed to send close message to peer: %s", err)
	}

	if err := p.conn.Close(); err != nil {
		p.log.Errorf("failed to close connection to peer: %s", err)
	}
}
// String returns the peer's ID in its string form (the idS field).
func (p *Peer) String() string {
	id := p.idS
	return id
}
// writeWithTimeout writes buf to the peer's connection, giving up when ctx is
// done or after 3 seconds, whichever comes first.
//
// It returns the number of bytes written and the write error, or 0 and a
// timeout error if the deadline passed first. It does not take connMu; the
// caller is responsible for any locking.
//
// On timeout the write itself is not interrupted: the background goroutine
// stays blocked in conn.Write until the write completes or the connection is
// closed.
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	type writeResult struct {
		n   int
		err error
	}

	// Buffered so the writer goroutine can always hand over its result and
	// exit, even when nobody is waiting for it any more.
	done := make(chan writeResult, 1)
	go func() {
		n, err := p.conn.Write(buf)
		done <- writeResult{n: n, err: err}
	}()

	select {
	case <-ctx.Done():
		return 0, fmt.Errorf("write operation timed out")
	case res := <-done:
		return res.n, res.err
	}
}

99
relay/server/relay.go Normal file
View File

@ -0,0 +1,99 @@
package server
import (
"context"
"fmt"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
// Relay accepts peer connections, registers the peers in a store and closes
// them on shutdown.
type Relay struct {
// store is the registry of connected peers; peers are added in Accept and
// removed when their Work loop returns.
store *Store
// closed is checked by Accept, which returns without registering the
// connection when it is true.
closed bool
// closeMu is held in read mode by Accept and in write mode by Close.
closeMu sync.RWMutex
}
// NewRelay creates a Relay backed by a fresh, empty peer store.
func NewRelay() *Relay {
	relay := new(Relay)
	relay.store = NewStore()
	return relay
}
// Accept takes ownership of a freshly accepted connection. It performs the
// hello handshake, registers the resulting peer in the store and starts the
// peer's read loop in its own goroutine; when that loop ends the peer is
// removed from the store again.
//
// The connection is closed here if the relay is already closed or if the
// handshake fails.
//
// closeMu is held in read mode for the whole call, including the handshake,
// so Close (which takes the write lock) cannot run between the closed check
// and the peer being added to the store.
func (r *Relay) Accept(conn net.Conn) {
	r.closeMu.RLock()
	defer r.closeMu.RUnlock()

	if r.closed {
		// Release the connection instead of silently dropping it, the same
		// way the handshake-failure path below does.
		if cErr := conn.Close(); cErr != nil {
			log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
		}
		return
	}

	peerID, err := handShake(conn)
	if err != nil {
		log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
		if cErr := conn.Close(); cErr != nil {
			log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
		}
		return
	}

	peer := NewPeer(peerID, conn, r.store)
	peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
	r.store.AddPeer(peer)

	go func() {
		peer.Work()
		r.store.DeletePeer(peer)
		peer.log.Debugf("relay connection closed")
	}()
}
// Close marks the relay as closed and gracefully closes the connection of
// every peer currently in the store, in parallel. It returns once all peers
// have been handled.
//
// ctx is passed to each peer's CloseGracefully and bounds how long sending
// the close messages may take. closeMu is held exclusively for the whole
// call, so Close waits for in-flight Accept calls to finish first, and any
// Accept that runs afterwards observes closed == true.
func (r *Relay) Close(ctx context.Context) {
	log.Infof("closeing connection with all peers")

	r.closeMu.Lock()
	defer r.closeMu.Unlock()

	// Without this the closed check in Accept could never trigger.
	r.closed = true

	var wg sync.WaitGroup
	for _, peer := range r.store.Peers() {
		wg.Add(1)
		go func(p *Peer) {
			defer wg.Done()
			p.CloseGracefully(ctx)
		}(peer)
	}
	wg.Wait()
}
// handShake reads the first message from conn, which must be a hello message,
// and answers it with a hello response.
//
// It returns the peer ID carried by the hello message. An error is returned
// when the read fails, the message type cannot be determined or is not a
// hello, the hello cannot be unmarshalled, or the response cannot be written.
// The connection is never closed here; that is left to the caller.
func handShake(conn net.Conn) ([]byte, error) {
	raw := make([]byte, messages.MaxHandshakeSize)
	size, err := conn.Read(raw)
	if err != nil {
		log.Errorf("failed to read message: %s", err)
		return nil, err
	}
	hello := raw[:size]

	kind, err := messages.DetermineClientMsgType(hello)
	if err != nil {
		return nil, err
	}
	if kind != messages.MsgTypeHello {
		tErr := fmt.Errorf("invalid message type")
		log.Errorf("failed to handshake: %s", tErr)
		return nil, tErr
	}

	peerID, err := messages.UnmarshalHelloMsg(hello)
	if err != nil {
		log.Errorf("failed to handshake: %s", err)
		return nil, err
	}

	if _, err = conn.Write(messages.MarshalHelloResponse()); err != nil {
		return nil, err
	}
	return peerID, nil
}

View File

@ -1,36 +1,27 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp"
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
)
const (
bufferSize = 8820
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
type Server struct {
store *Store
storeMu sync.RWMutex
UDPListener listener.Listener
WSListener listener.Listener
relay *Relay
uDPListener listener.Listener
wSListener listener.Listener
}
func NewServer() *Server {
return &Server{
store: NewStore(),
storeMu: sync.RWMutex{},
relay: NewRelay(),
}
}
@ -38,21 +29,21 @@ func (r *Server) Listen(address string) error {
wg := sync.WaitGroup{}
wg.Add(2)
r.WSListener = ws.NewListener(address)
r.wSListener = ws.NewListener(address)
var wslErr error
go func() {
defer wg.Done()
wslErr = r.WSListener.Listen(r.accept)
wslErr = r.wSListener.Listen(r.relay.Accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
}()
r.UDPListener = udp.NewListener(address)
r.uDPListener = udp.NewListener(address)
var udpLErr error
go func() {
defer wg.Done()
udpLErr = r.UDPListener.Listen(r.accept)
udpLErr = r.uDPListener.Listen(r.relay.Accept)
if udpLErr != nil {
log.Errorf("failed to bind ws server: %s", udpLErr)
}
@ -64,131 +55,21 @@ func (r *Server) Listen(address string) error {
func (r *Server) Close() error {
var wErr error
if r.WSListener != nil {
wErr = r.WSListener.Close()
// stop accepting new connections
if r.wSListener != nil {
wErr = r.wSListener.Close()
}
var uErr error
if r.UDPListener != nil {
uErr = r.UDPListener.Close()
if r.uDPListener != nil {
uErr = r.uDPListener.Close()
}
r.sendCloseMsgs()
r.WSListener.WaitForExitAcceptedConns()
// close accepted connections gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
r.relay.Close(ctx)
err := errors.Join(wErr, uErr)
return err
}
func (r *Server) accept(conn net.Conn) {
peer, err := handShake(conn)
if err != nil {
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
defer func() {
r.store.DeletePeer(peer)
peer.Log.Infof("relay connection closed")
}()
buf := make([]byte, bufferSize)
for {
n, err := conn.Read(buf)
if err != nil {
if err != io.EOF {
peer.Log.Errorf("failed to read message: %s", err)
}
return
}
msg := buf[:n]
msgType, err := messages.DetermineClientMsgType(msg)
if err != nil {
peer.Log.Errorf("failed to determine message type: %s", err)
return
}
switch msgType {
case messages.MsgTypeTransport:
peerID, err := messages.UnmarshalTransportID(msg)
if err != nil {
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
continue
}
stringPeerID := messages.HashIDToString(peerID)
dp, ok := r.store.Peer(stringPeerID)
if !ok {
peer.Log.Errorf("peer not found: %s", stringPeerID)
continue
}
err = messages.UpdateTransportMsg(msg, peer.ID())
if err != nil {
peer.Log.Errorf("failed to update transport message: %s", err)
continue
}
_, err = dp.conn.Write(msg)
if err != nil {
peer.Log.Errorf("failed to write transport message to: %s", dp.String())
}
case messages.MsgClose:
peer.Log.Infof("peer disconnected gracefully")
_ = conn.Close()
return
}
}
}
func (r *Server) sendCloseMsgs() {
msg := messages.MarshalCloseMsg()
r.storeMu.Lock()
log.Debugf("sending close messages to %d peers", len(r.store.peers))
for _, p := range r.store.peers {
_, err := p.conn.Write(msg)
if err != nil {
log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
log.Errorf("failed to close connection to peer: %s", err)
}
}
r.storeMu.Unlock()
}
func handShake(conn net.Conn) (*Peer, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
log.Errorf("failed to read message: %s", err)
return nil, err
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
return nil, err
}
if msgType != messages.MsgTypeHello {
tErr := fmt.Errorf("invalid message type")
log.Errorf("failed to handshake: %s", tErr)
return nil, tErr
}
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
p := NewPeer(peerId, conn)
msg := messages.MarshalHelloResponse()
_, err = conn.Write(msg)
return p, err
}