mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 15:43:12 +01:00
[relay-server] Move the handshake logic to separated struct (#2648)
* Move the handshake logic to separated struct - The server will response to the client after it ready to process the peer - Preload the response messages * Fix deprecated lint issue * Fix error handling * [relay-server] Relay measure auth time (#2675) Measure the Relay client's authentication time
This commit is contained in:
parent
3a88ac78ff
commit
d93dd4fc7f
@ -16,8 +16,10 @@ const (
|
|||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
metric.Meter
|
metric.Meter
|
||||||
|
|
||||||
TransferBytesSent metric.Int64Counter
|
TransferBytesSent metric.Int64Counter
|
||||||
TransferBytesRecv metric.Int64Counter
|
TransferBytesRecv metric.Int64Counter
|
||||||
|
AuthenticationTime metric.Float64Histogram
|
||||||
|
PeerStoreTime metric.Float64Histogram
|
||||||
|
|
||||||
peers metric.Int64UpDownCounter
|
peers metric.Int64UpDownCounter
|
||||||
peerActivityChan chan string
|
peerActivityChan chan string
|
||||||
@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Metrics{
|
m := &Metrics{
|
||||||
Meter: meter,
|
Meter: meter,
|
||||||
TransferBytesSent: bytesSent,
|
TransferBytesSent: bytesSent,
|
||||||
TransferBytesRecv: bytesRecv,
|
TransferBytesRecv: bytesRecv,
|
||||||
peers: peers,
|
AuthenticationTime: authTime,
|
||||||
|
PeerStoreTime: peerStoreTime,
|
||||||
|
peers: peers,
|
||||||
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
peerActivityChan: make(chan string, 10),
|
peerActivityChan: make(chan string, 10),
|
||||||
@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
|
|||||||
m.peerLastActive[id] = time.Time{}
|
m.peerLastActive[id] = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordAuthenticationTime measures the time taken for peer authentication
|
||||||
|
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
|
||||||
|
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordPeerStoreTime measures the time to store the peer in map
|
||||||
|
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||||
|
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||||
func (m *Metrics) PeerDisconnected(id string) {
|
func (m *Metrics) PeerDisconnected(id string) {
|
||||||
m.peers.Add(m.ctx, -1)
|
m.peers.Add(m.ctx, -1)
|
||||||
@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStandardBucketBoundaries() []float64 {
|
||||||
|
return []float64{
|
||||||
|
0.1,
|
||||||
|
0.5,
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
50,
|
||||||
|
100,
|
||||||
|
500,
|
||||||
|
1000,
|
||||||
|
5000,
|
||||||
|
10000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
153
relay/server/handshake.go
Normal file
153
relay/server/handshake.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
//nolint:staticcheck
|
||||||
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
|
//nolint:staticcheck
|
||||||
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// preparedMsg contains the marshalled success response messages
|
||||||
|
type preparedMsg struct {
|
||||||
|
responseHelloMsg []byte
|
||||||
|
responseAuthMsg []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
|
||||||
|
rhm, err := marshalResponseHelloMsg(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ram, err := messages.MarshalAuthResponse(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &preparedMsg{
|
||||||
|
responseHelloMsg: rhm,
|
||||||
|
responseAuthMsg: ram,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||||
|
addr := &address.Address{URL: instanceURL}
|
||||||
|
addrData, err := addr.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal response address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
|
||||||
|
}
|
||||||
|
return responseMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshake struct {
|
||||||
|
conn net.Conn
|
||||||
|
validator auth.Validator
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
|
handshakeMethodAuth bool
|
||||||
|
peerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||||
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
|
n, err := h.conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = messages.ValidateVersion(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
bytePeerID []byte
|
||||||
|
peerID string
|
||||||
|
)
|
||||||
|
switch msgType {
|
||||||
|
//nolint:staticcheck
|
||||||
|
case messages.MsgTypeHello:
|
||||||
|
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
case messages.MsgTypeAuth:
|
||||||
|
h.handshakeMethodAuth = true
|
||||||
|
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.peerID = peerID
|
||||||
|
return bytePeerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeResponse() error {
|
||||||
|
var responseMsg []byte
|
||||||
|
if h.handshakeMethodAuth {
|
||||||
|
responseMsg = h.preparedMsg.responseAuthMsg
|
||||||
|
} else {
|
||||||
|
responseMsg = h.preparedMsg.responseHelloMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||||
|
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
//nolint:staticcheck
|
||||||
|
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||||
|
|
||||||
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
|
||||||
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
@ -7,16 +7,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
|
||||||
//nolint:staticcheck
|
|
||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +25,7 @@ type Relay struct {
|
|||||||
|
|
||||||
store *Store
|
store *Store
|
||||||
instanceURL string
|
instanceURL string
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
closeMu sync.RWMutex
|
closeMu sync.RWMutex
|
||||||
@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
|||||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
metricsCancel()
|
||||||
|
return nil, fmt.Errorf("prepare message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
|||||||
|
|
||||||
// Accept start to handle a new peer connection
|
// Accept start to handle a new peer connection
|
||||||
func (r *Relay) Accept(conn net.Conn) {
|
func (r *Relay) Accept(conn net.Conn) {
|
||||||
|
acceptTime := time.Now()
|
||||||
r.closeMu.RLock()
|
r.closeMu.RLock()
|
||||||
defer r.closeMu.RUnlock()
|
defer r.closeMu.RUnlock()
|
||||||
if r.closed {
|
if r.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID, err := r.handshake(conn)
|
h := handshake{
|
||||||
|
conn: conn,
|
||||||
|
validator: r.validator,
|
||||||
|
preparedMsg: r.preparedMsg,
|
||||||
|
}
|
||||||
|
peerID, err := h.handshakeReceive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to handshake: %s", err)
|
log.Errorf("failed to handshake: %s", err)
|
||||||
cErr := conn.Close()
|
if cErr := conn.Close(); cErr != nil {
|
||||||
if cErr != nil {
|
|
||||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
|
|
||||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
|
storeTime := time.Now()
|
||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
|
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||||
r.metrics.PeerConnected(peer.String())
|
r.metrics.PeerConnected(peer.String())
|
||||||
go func() {
|
go func() {
|
||||||
peer.Work()
|
peer.Work()
|
||||||
@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
peer.log.Debugf("relay connection closed")
|
peer.log.Debugf("relay connection closed")
|
||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if err := h.handshakeResponse(); err != nil {
|
||||||
|
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||||
|
peer.Close()
|
||||||
|
}
|
||||||
|
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown closes the relay server
|
// Shutdown closes the relay server
|
||||||
@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
|||||||
func (r *Relay) InstanceURL() string {
|
func (r *Relay) InstanceURL() string {
|
||||||
return r.instanceURL
|
return r.instanceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = messages.ValidateVersion(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
responseMsg []byte
|
|
||||||
peerID []byte
|
|
||||||
)
|
|
||||||
switch msgType {
|
|
||||||
//nolint:staticcheck
|
|
||||||
case messages.MsgTypeHello:
|
|
||||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
case messages.MsgTypeAuth:
|
|
||||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.Write(responseMsg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
|
||||||
//nolint:staticcheck
|
|
||||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &address.Address{URL: r.instanceURL}
|
|
||||||
addrData, err := addr.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
|
||||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
|
|
||||||
if err := r.validator.Validate(authPayload); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user