netbird/relay/server/server.go

195 lines
4.1 KiB
Go
Raw Normal View History

2024-05-17 17:43:28 +02:00
package server
import (
2024-05-26 22:14:33 +02:00
"errors"
2024-05-17 17:43:28 +02:00
"fmt"
"io"
"net"
2024-05-26 22:14:33 +02:00
"sync"
2024-05-17 17:43:28 +02:00
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp"
ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr"
2024-05-17 17:43:28 +02:00
)
2024-06-26 15:26:19 +02:00
// bufferSize is the size, in bytes, of the read buffer that accept allocates
// for each connection.
const bufferSize = 8820
2024-05-17 17:43:28 +02:00
type Server struct {
2024-06-05 19:49:30 +02:00
store *Store
storeMu sync.RWMutex
2024-05-17 17:43:28 +02:00
2024-05-26 22:14:33 +02:00
UDPListener listener.Listener
WSListener listener.Listener
2024-05-17 17:43:28 +02:00
}
func NewServer() *Server {
return &Server{
2024-06-05 19:49:30 +02:00
store: NewStore(),
storeMu: sync.RWMutex{},
2024-05-17 17:43:28 +02:00
}
}
func (r *Server) Listen(address string) error {
2024-05-26 22:14:33 +02:00
wg := sync.WaitGroup{}
wg.Add(2)
r.WSListener = ws.NewListener(address)
var wslErr error
go func() {
defer wg.Done()
wslErr = r.WSListener.Listen(r.accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
}()
r.UDPListener = udp.NewListener(address)
var udpLErr error
go func() {
defer wg.Done()
udpLErr = r.UDPListener.Listen(r.accept)
if udpLErr != nil {
log.Errorf("failed to bind ws server: %s", udpLErr)
}
}()
err := errors.Join(wslErr, udpLErr)
return err
2024-05-17 17:43:28 +02:00
}
func (r *Server) Close() error {
2024-05-26 22:14:33 +02:00
var wErr error
if r.WSListener != nil {
wErr = r.WSListener.Close()
}
var uErr error
if r.UDPListener != nil {
uErr = r.UDPListener.Close()
2024-05-17 17:43:28 +02:00
}
2024-06-05 19:49:30 +02:00
r.sendCloseMsgs()
r.WSListener.WaitForExitAcceptedConns()
2024-05-26 22:14:33 +02:00
err := errors.Join(wErr, uErr)
return err
2024-05-17 17:43:28 +02:00
}
func (r *Server) accept(conn net.Conn) {
peer, err := handShake(conn)
if err != nil {
2024-06-25 15:13:08 +02:00
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
2024-05-17 17:43:28 +02:00
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
2024-05-26 22:14:33 +02:00
peer.Log.Infof("peer connected from: %s", conn.RemoteAddr())
2024-05-17 17:43:28 +02:00
r.store.AddPeer(peer)
defer func() {
r.store.DeletePeer(peer)
2024-06-05 19:49:30 +02:00
peer.Log.Infof("relay connection closed")
2024-05-17 17:43:28 +02:00
}()
2024-06-26 15:26:19 +02:00
buf := make([]byte, bufferSize)
2024-05-17 17:43:28 +02:00
for {
n, err := conn.Read(buf)
if err != nil {
if err != io.EOF {
peer.Log.Errorf("failed to read message: %s", err)
}
return
}
2024-06-26 15:26:19 +02:00
msg := buf[:n]
msgType, err := messages.DetermineClientMsgType(msg)
2024-05-17 17:43:28 +02:00
if err != nil {
peer.Log.Errorf("failed to determine message type: %s", err)
2024-05-17 17:43:28 +02:00
return
}
switch msgType {
case messages.MsgTypeTransport:
2024-05-23 13:24:02 +02:00
peerID, err := messages.UnmarshalTransportID(msg)
2024-05-17 17:43:28 +02:00
if err != nil {
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
continue
}
2024-06-26 15:26:19 +02:00
stringPeerID := messages.HashIDToString(peerID)
dp, ok := r.store.Peer(stringPeerID)
if !ok {
peer.Log.Errorf("peer not found: %s", stringPeerID)
continue
}
err = messages.UpdateTransportMsg(msg, peer.ID())
if err != nil {
peer.Log.Errorf("failed to update transport message: %s", err)
continue
}
_, err = dp.conn.Write(msg)
if err != nil {
peer.Log.Errorf("failed to write transport message to: %s", dp.String())
}
2024-06-05 19:49:30 +02:00
case messages.MsgClose:
peer.Log.Infof("peer disconnected gracefully")
_ = conn.Close()
return
}
}
}
func (r *Server) sendCloseMsgs() {
msg := messages.MarshalCloseMsg()
r.storeMu.Lock()
log.Debugf("sending close messages to %d peers", len(r.store.peers))
for _, p := range r.store.peers {
_, err := p.conn.Write(msg)
if err != nil {
log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
log.Errorf("failed to close connection to peer: %s", err)
2024-05-17 17:43:28 +02:00
}
}
2024-06-05 19:49:30 +02:00
r.storeMu.Unlock()
2024-05-17 17:43:28 +02:00
}
func handShake(conn net.Conn) (*Peer, error) {
2024-06-26 16:22:26 +02:00
buf := make([]byte, messages.MaxHandshakeSize)
2024-05-17 17:43:28 +02:00
n, err := conn.Read(buf)
if err != nil {
log.Errorf("failed to read message: %s", err)
return nil, err
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
return nil, err
}
if msgType != messages.MsgTypeHello {
tErr := fmt.Errorf("invalid message type")
log.Errorf("failed to handshake: %s", tErr)
return nil, tErr
}
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
p := NewPeer(peerId, conn)
msg := messages.MarshalHelloResponse()
_, err = conn.Write(msg)
return p, err
2024-05-17 17:43:28 +02:00
}