mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 11:21:04 +01:00
4d0e16f2d0
- Implement thread-safe graceful close logic; organise the server code
100 lines
1.9 KiB
Go
100 lines
1.9 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/relay/messages"
|
|
)
|
|
|
|
type Relay struct {
|
|
store *Store
|
|
|
|
closed bool
|
|
closeMu sync.RWMutex
|
|
}
|
|
|
|
func NewRelay() *Relay {
|
|
return &Relay{
|
|
store: NewStore(),
|
|
}
|
|
}
|
|
|
|
func (r *Relay) Accept(conn net.Conn) {
|
|
r.closeMu.RLock()
|
|
defer r.closeMu.RUnlock()
|
|
if r.closed {
|
|
return
|
|
}
|
|
|
|
peerID, err := handShake(conn)
|
|
if err != nil {
|
|
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
|
|
cErr := conn.Close()
|
|
if cErr != nil {
|
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
|
}
|
|
return
|
|
}
|
|
|
|
peer := NewPeer(peerID, conn, r.store)
|
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
|
r.store.AddPeer(peer)
|
|
|
|
go func() {
|
|
peer.Work()
|
|
r.store.DeletePeer(peer)
|
|
peer.log.Debugf("relay connection closed")
|
|
}()
|
|
}
|
|
|
|
func (r *Relay) Close(ctx context.Context) {
|
|
log.Infof("closeing connection with all peers")
|
|
r.closeMu.Lock()
|
|
wg := sync.WaitGroup{}
|
|
peers := r.store.Peers()
|
|
for _, peer := range peers {
|
|
wg.Add(1)
|
|
go func(p *Peer) {
|
|
p.CloseGracefully(ctx)
|
|
wg.Done()
|
|
}(peer)
|
|
}
|
|
wg.Wait()
|
|
r.closeMu.Unlock()
|
|
}
|
|
|
|
func handShake(conn net.Conn) ([]byte, error) {
|
|
buf := make([]byte, messages.MaxHandshakeSize)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
log.Errorf("failed to read message: %s", err)
|
|
return nil, err
|
|
}
|
|
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if msgType != messages.MsgTypeHello {
|
|
tErr := fmt.Errorf("invalid message type")
|
|
log.Errorf("failed to handshake: %s", tErr)
|
|
return nil, tErr
|
|
}
|
|
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
|
|
if err != nil {
|
|
log.Errorf("failed to handshake: %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
msg := messages.MarshalHelloResponse()
|
|
_, err = conn.Write(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return peerID, nil
|
|
}
|