mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
- Remove WaitForExitAcceptedConns logic from server
- Implement thread safe gracefully close logic - organise the server code
This commit is contained in:
parent
3fcdb51376
commit
4d0e16f2d0
@ -5,5 +5,4 @@ import "net"
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Close() error
|
||||
WaitForExitAcceptedConns()
|
||||
}
|
||||
|
@ -21,11 +21,6 @@ type Listener struct {
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -18,7 +17,6 @@ import (
|
||||
type Listener struct {
|
||||
address string
|
||||
|
||||
wg sync.WaitGroup
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
@ -63,14 +61,7 @@ func (l *Listener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
|
||||
wsConn, err := websocket.Accept(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept ws connection: %s", err)
|
||||
@ -91,5 +82,4 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
l.acceptFn(conn)
|
||||
return
|
||||
}
|
@ -1,33 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
Log *log.Entry
|
||||
idS string
|
||||
idB []byte
|
||||
conn net.Conn
|
||||
log *log.Entry
|
||||
idS string
|
||||
idB []byte
|
||||
conn net.Conn
|
||||
connMu sync.RWMutex
|
||||
store *Store
|
||||
}
|
||||
|
||||
func NewPeer(id []byte, conn net.Conn) *Peer {
|
||||
func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
|
||||
stringID := messages.HashIDToString(id)
|
||||
return &Peer{
|
||||
Log: log.WithField("peer_id", stringID),
|
||||
idB: id,
|
||||
idS: stringID,
|
||||
conn: conn,
|
||||
log: log.WithField("peer_id", stringID),
|
||||
idS: stringID,
|
||||
idB: id,
|
||||
conn: conn,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
func (p *Peer) ID() []byte {
|
||||
return p.idB
|
||||
|
||||
func (p *Peer) Work() {
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
msg := buf[:n]
|
||||
|
||||
msgType, err := messages.DetermineClientMsgType(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
switch msgType {
|
||||
case messages.MsgHealthCheck:
|
||||
case messages.MsgTypeTransport:
|
||||
peerID, err := messages.UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := p.store.Peer(stringPeerID)
|
||||
if !ok {
|
||||
p.log.Errorf("peer not found: %s", stringPeerID)
|
||||
continue
|
||||
}
|
||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
_, err = dp.Write(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to write transport message to: %s", dp.String())
|
||||
}
|
||||
case messages.MsgClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
_ = p.conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes data to the connection
|
||||
// it has been called by the remote peer
|
||||
func (p *Peer) Write(b []byte) (int, error) {
|
||||
p.connMu.RLock()
|
||||
defer p.connMu.RUnlock()
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.connMu.Lock()
|
||||
_, err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||
if err != nil {
|
||||
log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
|
||||
defer p.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (p *Peer) String() string {
|
||||
return p.idS
|
||||
}
|
||||
|
||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
writeDone := make(chan struct{})
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
_, err = p.conn.Write(buf)
|
||||
close(writeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, fmt.Errorf("write operation timed out")
|
||||
case <-writeDone:
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
99
relay/server/relay.go
Normal file
99
relay/server/relay.go
Normal file
@ -0,0 +1,99 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type Relay struct {
|
||||
store *Store
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRelay() *Relay {
|
||||
return &Relay{
|
||||
store: NewStore(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := handShake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
peer := NewPeer(peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
r.store.AddPeer(peer)
|
||||
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.store.DeletePeer(peer)
|
||||
peer.log.Debugf("relay connection closed")
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *Relay) Close(ctx context.Context) {
|
||||
log.Infof("closeing connection with all peers")
|
||||
r.closeMu.Lock()
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
wg.Add(1)
|
||||
go func(p *Peer) {
|
||||
p.CloseGracefully(ctx)
|
||||
wg.Done()
|
||||
}(peer)
|
||||
}
|
||||
wg.Wait()
|
||||
r.closeMu.Unlock()
|
||||
}
|
||||
|
||||
func handShake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msgType != messages.MsgTypeHello {
|
||||
tErr := fmt.Errorf("invalid message type")
|
||||
log.Errorf("failed to handshake: %s", tErr)
|
||||
return nil, tErr
|
||||
}
|
||||
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := messages.MarshalHelloResponse()
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return peerID, nil
|
||||
}
|
@ -1,36 +1,27 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
||||
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
store *Store
|
||||
storeMu sync.RWMutex
|
||||
|
||||
UDPListener listener.Listener
|
||||
WSListener listener.Listener
|
||||
relay *Relay
|
||||
uDPListener listener.Listener
|
||||
wSListener listener.Listener
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
store: NewStore(),
|
||||
storeMu: sync.RWMutex{},
|
||||
relay: NewRelay(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,21 +29,21 @@ func (r *Server) Listen(address string) error {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
r.WSListener = ws.NewListener(address)
|
||||
r.wSListener = ws.NewListener(address)
|
||||
var wslErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wslErr = r.WSListener.Listen(r.accept)
|
||||
wslErr = r.wSListener.Listen(r.relay.Accept)
|
||||
if wslErr != nil {
|
||||
log.Errorf("failed to bind ws server: %s", wslErr)
|
||||
}
|
||||
}()
|
||||
|
||||
r.UDPListener = udp.NewListener(address)
|
||||
r.uDPListener = udp.NewListener(address)
|
||||
var udpLErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
udpLErr = r.UDPListener.Listen(r.accept)
|
||||
udpLErr = r.uDPListener.Listen(r.relay.Accept)
|
||||
if udpLErr != nil {
|
||||
log.Errorf("failed to bind ws server: %s", udpLErr)
|
||||
}
|
||||
@ -64,131 +55,21 @@ func (r *Server) Listen(address string) error {
|
||||
|
||||
func (r *Server) Close() error {
|
||||
var wErr error
|
||||
if r.WSListener != nil {
|
||||
wErr = r.WSListener.Close()
|
||||
// stop service new connections
|
||||
if r.wSListener != nil {
|
||||
wErr = r.wSListener.Close()
|
||||
}
|
||||
|
||||
var uErr error
|
||||
if r.UDPListener != nil {
|
||||
uErr = r.UDPListener.Close()
|
||||
if r.uDPListener != nil {
|
||||
uErr = r.uDPListener.Close()
|
||||
}
|
||||
|
||||
r.sendCloseMsgs()
|
||||
|
||||
r.WSListener.WaitForExitAcceptedConns()
|
||||
// close accepted connections gracefully
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
r.relay.Close(ctx)
|
||||
|
||||
err := errors.Join(wErr, uErr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Server) accept(conn net.Conn) {
|
||||
peer, err := handShake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
|
||||
r.store.AddPeer(peer)
|
||||
defer func() {
|
||||
r.store.DeletePeer(peer)
|
||||
peer.Log.Infof("relay connection closed")
|
||||
}()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
peer.Log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
msg := buf[:n]
|
||||
|
||||
msgType, err := messages.DetermineClientMsgType(msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
switch msgType {
|
||||
case messages.MsgTypeTransport:
|
||||
peerID, err := messages.UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := r.store.Peer(stringPeerID)
|
||||
if !ok {
|
||||
peer.Log.Errorf("peer not found: %s", stringPeerID)
|
||||
continue
|
||||
}
|
||||
err = messages.UpdateTransportMsg(msg, peer.ID())
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to update transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
_, err = dp.conn.Write(msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to write transport message to: %s", dp.String())
|
||||
}
|
||||
case messages.MsgClose:
|
||||
peer.Log.Infof("peer disconnected gracefully")
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Server) sendCloseMsgs() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
|
||||
r.storeMu.Lock()
|
||||
log.Debugf("sending close messages to %d peers", len(r.store.peers))
|
||||
for _, p := range r.store.peers {
|
||||
_, err := p.conn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
r.storeMu.Unlock()
|
||||
}
|
||||
|
||||
func handShake(conn net.Conn) (*Peer, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msgType != messages.MsgTypeHello {
|
||||
tErr := fmt.Errorf("invalid message type")
|
||||
log.Errorf("failed to handshake: %s", tErr)
|
||||
return nil, tErr
|
||||
}
|
||||
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
p := NewPeer(peerId, conn)
|
||||
|
||||
msg := messages.MarshalHelloResponse()
|
||||
_, err = conn.Write(msg)
|
||||
return p, err
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user