netbird/relay/server/relay.go

164 lines
4.2 KiB
Go
Raw Normal View History

package server
import (
"context"
"fmt"
"net"
2024-07-29 21:53:07 +02:00
"net/url"
"sync"
log "github.com/sirupsen/logrus"
2024-07-24 16:26:26 +02:00
"go.opentelemetry.io/otel/metric"
2024-07-05 16:12:30 +02:00
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
2024-07-24 16:26:26 +02:00
"github.com/netbirdio/netbird/relay/metrics"
)
2024-07-29 21:53:07 +02:00
// Relay represents the relay server
type Relay struct {
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
2024-07-05 16:12:30 +02:00
2024-07-24 16:34:47 +02:00
store *Store
instanceURL string
closed bool
closeMu sync.RWMutex
}
2024-07-29 21:53:07 +02:00
// NewRelay creates a new Relay instance
//
// Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
// metrics for the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
// address as the relay server's instance URL.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
// instance URL depends on this value.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
//
// Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
// Otherwise, the error contains the details of what went wrong.
2024-07-24 16:26:26 +02:00
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
ctx, metricsCancel := context.WithCancel(context.Background())
m, err := metrics.NewMetrics(ctx, meter)
2024-07-24 16:26:26 +02:00
if err != nil {
metricsCancel()
2024-07-24 16:26:26 +02:00
return nil, fmt.Errorf("creating app metrics: %v", err)
}
2024-07-02 11:57:17 +02:00
r := &Relay{
metrics: m,
metricsCancel: metricsCancel,
validator: validator,
store: NewStore(),
}
2024-07-02 11:57:17 +02:00
if tlsSupport {
2024-07-24 16:34:47 +02:00
r.instanceURL = fmt.Sprintf("rels://%s", exposedAddress)
2024-07-02 11:57:17 +02:00
} else {
2024-07-24 16:34:47 +02:00
r.instanceURL = fmt.Sprintf("rel://%s", exposedAddress)
2024-07-02 11:57:17 +02:00
}
2024-07-29 21:53:07 +02:00
_, err = url.ParseRequestURI(r.instanceURL)
if err != nil {
return nil, fmt.Errorf("invalid exposed address: %v", err)
}
2024-07-05 16:12:30 +02:00
2024-07-24 16:26:26 +02:00
return r, nil
}
2024-07-29 21:53:07 +02:00
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) {
r.closeMu.RLock()
defer r.closeMu.RUnlock()
if r.closed {
return
}
2024-07-29 21:53:07 +02:00
peerID, err := r.handshake(conn)
if err != nil {
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()
}
2024-07-29 21:53:07 +02:00
// Close closes the relay server
// It closes the connection with all peers in gracefully and stops accepting new connections.
func (r *Relay) Close(ctx context.Context) {
2024-07-08 21:53:20 +02:00
log.Infof("close connection with all peers")
r.closeMu.Lock()
wg := sync.WaitGroup{}
peers := r.store.Peers()
for _, peer := range peers {
wg.Add(1)
go func(p *Peer) {
p.CloseGracefully(ctx)
wg.Done()
}(peer)
}
wg.Wait()
r.metricsCancel()
r.closeMu.Unlock()
}
2024-07-29 21:53:07 +02:00
// InstanceURL reports the URL under which this relay server is reachable.
// NewRelay builds it from the exposed address, using the "rels" scheme when
// TLS is supported and "rel" otherwise.
func (r *Relay) InstanceURL() (instanceURL string) {
	instanceURL = r.instanceURL
	return instanceURL
}
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
log.Errorf("failed to read message: %s", err)
return nil, err
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
return nil, err
}
2024-07-02 11:57:17 +02:00
if msgType != messages.MsgTypeHello {
tErr := fmt.Errorf("invalid message type")
log.Errorf("failed to handshake: %s", tErr)
return nil, tErr
}
2024-07-02 11:57:17 +02:00
2024-07-05 16:12:30 +02:00
peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
2024-07-05 16:12:30 +02:00
if err := r.validator.Validate(authPayload); err != nil {
log.Errorf("failed to authenticate connection: %s", err)
return nil, err
2024-07-05 16:12:30 +02:00
}
2024-07-24 16:34:47 +02:00
msg, _ := messages.MarshalHelloResponse(r.instanceURL)
_, err = conn.Write(msg)
if err != nil {
return nil, err
}
return peerID, nil
}