mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
2d1bf3982d
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
250 lines
7.1 KiB
Go
250 lines
7.1 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/metric"
|
|
|
|
"github.com/netbirdio/netbird/relay/auth"
|
|
"github.com/netbirdio/netbird/relay/messages"
|
|
//nolint:staticcheck
|
|
"github.com/netbirdio/netbird/relay/messages/address"
|
|
//nolint:staticcheck
|
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
|
"github.com/netbirdio/netbird/relay/metrics"
|
|
)
|
|
|
|
// Relay represents the relay server
|
|
type Relay struct {
|
|
metrics *metrics.Metrics
|
|
metricsCancel context.CancelFunc
|
|
validator auth.Validator
|
|
|
|
store *Store
|
|
instanceURL string
|
|
|
|
closed bool
|
|
closeMu sync.RWMutex
|
|
}
|
|
|
|
// NewRelay creates a new Relay instance
|
|
//
|
|
// Parameters:
|
|
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
|
|
// metrics for the relay server.
|
|
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
|
|
// address as the relay server's instance URL.
|
|
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
|
|
// instance URL depends on this value.
|
|
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
|
|
// peers.
|
|
//
|
|
// Returns:
|
|
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
|
|
// Otherwise, the error contains the details of what went wrong.
|
|
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
|
|
ctx, metricsCancel := context.WithCancel(context.Background())
|
|
m, err := metrics.NewMetrics(ctx, meter)
|
|
if err != nil {
|
|
metricsCancel()
|
|
return nil, fmt.Errorf("creating app metrics: %v", err)
|
|
}
|
|
|
|
r := &Relay{
|
|
metrics: m,
|
|
metricsCancel: metricsCancel,
|
|
validator: validator,
|
|
store: NewStore(),
|
|
}
|
|
|
|
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
|
|
if err != nil {
|
|
metricsCancel()
|
|
return nil, fmt.Errorf("get instance URL: %v", err)
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// getInstanceURL turns the exposed address into the relay instance URL.
//
// If exposedAddress carries no scheme, "rels://" is prepended when
// tlsSupported is true and "rel://" otherwise. An address containing more
// than one "://" is rejected. The result must parse as a request URI with a
// "rel" or "rels" scheme and a non-empty host. The normalized URL string is
// returned.
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
	addr := exposedAddress
	switch strings.Count(exposedAddress, "://") {
	case 0:
		if tlsSupported {
			addr = "rels://" + exposedAddress
		} else {
			addr = "rel://" + exposedAddress
		}
	case 1:
		// The user supplied a scheme; it is validated below.
	default:
		return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
	}

	parsedURL, err := url.ParseRequestURI(addr)
	if err != nil {
		return "", fmt.Errorf("invalid exposed address: %w", err)
	}

	if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
		return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
	}

	// Inputs such as "" or "rels://" parse without error but have no host,
	// which would produce an unusable instance URL like "rel:".
	if parsedURL.Host == "" {
		return "", fmt.Errorf("invalid exposed address: missing host in %q", exposedAddress)
	}

	return parsedURL.String(), nil
}
|
|
|
|
// Accept start to handle a new peer connection
|
|
func (r *Relay) Accept(conn net.Conn) {
|
|
r.closeMu.RLock()
|
|
defer r.closeMu.RUnlock()
|
|
if r.closed {
|
|
return
|
|
}
|
|
|
|
peerID, err := r.handshake(conn)
|
|
if err != nil {
|
|
log.Errorf("failed to handshake: %s", err)
|
|
cErr := conn.Close()
|
|
if cErr != nil {
|
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
|
}
|
|
return
|
|
}
|
|
|
|
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
|
r.store.AddPeer(peer)
|
|
r.metrics.PeerConnected(peer.String())
|
|
go func() {
|
|
peer.Work()
|
|
r.store.DeletePeer(peer)
|
|
peer.log.Debugf("relay connection closed")
|
|
r.metrics.PeerDisconnected(peer.String())
|
|
}()
|
|
}
|
|
|
|
// Shutdown closes the relay server
|
|
// It closes the connection with all peers in gracefully and stops accepting new connections.
|
|
func (r *Relay) Shutdown(ctx context.Context) {
|
|
log.Infof("close connection with all peers")
|
|
r.closeMu.Lock()
|
|
wg := sync.WaitGroup{}
|
|
peers := r.store.Peers()
|
|
for _, peer := range peers {
|
|
wg.Add(1)
|
|
go func(p *Peer) {
|
|
p.CloseGracefully(ctx)
|
|
wg.Done()
|
|
}(peer)
|
|
}
|
|
wg.Wait()
|
|
r.metricsCancel()
|
|
r.closeMu.Unlock()
|
|
}
|
|
|
|
// InstanceURL returns the instance URL of the relay server
|
|
func (r *Relay) InstanceURL() string {
|
|
return r.instanceURL
|
|
}
|
|
|
|
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
|
buf := make([]byte, messages.MaxHandshakeSize)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
|
}
|
|
|
|
_, err = messages.ValidateVersion(buf[:n])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
|
}
|
|
|
|
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
|
}
|
|
|
|
var (
|
|
responseMsg []byte
|
|
peerID []byte
|
|
)
|
|
switch msgType {
|
|
//nolint:staticcheck
|
|
case messages.MsgTypeHello:
|
|
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
case messages.MsgTypeAuth:
|
|
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
default:
|
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = conn.Write(responseMsg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
}
|
|
|
|
return peerID, nil
|
|
}
|
|
|
|
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
|
//nolint:staticcheck
|
|
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
}
|
|
|
|
peerID := messages.HashIDToString(rawPeerID)
|
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
|
|
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
|
}
|
|
|
|
//nolint:staticcheck
|
|
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
|
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
|
}
|
|
|
|
addr := &address.Address{URL: r.instanceURL}
|
|
addrData, err := addr.Marshal()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
|
}
|
|
|
|
//nolint:staticcheck
|
|
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
|
}
|
|
return rawPeerID, responseMsg, nil
|
|
}
|
|
|
|
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
}
|
|
|
|
peerID := messages.HashIDToString(rawPeerID)
|
|
|
|
if err := r.validator.Validate(authPayload); err != nil {
|
|
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
|
}
|
|
|
|
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
|
}
|
|
|
|
return rawPeerID, responseMsg, nil
|
|
}
|