mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-09 15:25:20 +02:00
- Remove WaitForExitAcceptedConns logic from server
- Implement thread safe gracefully close logic - organise the server code
This commit is contained in:
99
relay/server/relay.go
Normal file
99
relay/server/relay.go
Normal file
@ -0,0 +1,99 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type Relay struct {
|
||||
store *Store
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRelay() *Relay {
|
||||
return &Relay{
|
||||
store: NewStore(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := handShake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
peer := NewPeer(peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
r.store.AddPeer(peer)
|
||||
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.store.DeletePeer(peer)
|
||||
peer.log.Debugf("relay connection closed")
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *Relay) Close(ctx context.Context) {
|
||||
log.Infof("closeing connection with all peers")
|
||||
r.closeMu.Lock()
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
wg.Add(1)
|
||||
go func(p *Peer) {
|
||||
p.CloseGracefully(ctx)
|
||||
wg.Done()
|
||||
}(peer)
|
||||
}
|
||||
wg.Wait()
|
||||
r.closeMu.Unlock()
|
||||
}
|
||||
|
||||
func handShake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msgType != messages.MsgTypeHello {
|
||||
tErr := fmt.Errorf("invalid message type")
|
||||
log.Errorf("failed to handshake: %s", tErr)
|
||||
return nil, tErr
|
||||
}
|
||||
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := messages.MarshalHelloResponse()
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return peerID, nil
|
||||
}
|
Reference in New Issue
Block a user