[relay-server] Move the handshake logic to a separate struct (#2648)

* Move the handshake logic to a separate struct

- The server will respond to the client once it is ready to process the peer
- Preload the response messages

* Fix deprecated lint issue

* Fix error handling

* [relay-server] Relay measure auth time (#2675)

Measure the Relay client's authentication time
This commit is contained in:
Zoltan Papp 2024-10-12 18:21:34 +02:00 committed by GitHub
parent 3a88ac78ff
commit d93dd4fc7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 223 additions and 109 deletions

View File

@ -18,6 +18,8 @@ type Metrics struct {
TransferBytesSent metric.Int64Counter TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter peers metric.Int64UpDownCounter
peerActivityChan chan string peerActivityChan chan string
@ -52,10 +54,22 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err return nil, err
} }
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
m := &Metrics{ m := &Metrics{
Meter: meter, Meter: meter,
TransferBytesSent: bytesSent, TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv, TransferBytesRecv: bytesRecv,
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers, peers: peers,
ctx: ctx, ctx: ctx,
@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
m.peerLastActive[id] = time.Time{} m.peerLastActive[id] = time.Time{}
} }
// RecordAuthenticationTime reports how long a peer's authentication took.
// The duration is converted to fractional milliseconds and recorded on the
// AuthenticationTime histogram.
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
	elapsedMs := float64(duration.Nanoseconds()) / 1e6
	m.AuthenticationTime.Record(m.ctx, elapsedMs)
}
// RecordPeerStoreTime reports how long it took to add a peer to the store.
// The duration is converted to fractional milliseconds and recorded on the
// PeerStoreTime histogram.
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
	elapsedMs := float64(duration.Nanoseconds()) / 1e6
	m.PeerStoreTime.Record(m.ctx, elapsedMs)
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) { func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1) m.peers.Add(m.ctx, -1)
@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
} }
} }
} }
// getStandardBucketBoundaries returns the explicit histogram bucket boundaries
// used by the timing histograms (whose metric names are in milliseconds).
// The boundaries follow a 1-5 progression per decade, from 0.1 up to 10000.
// A new slice is returned on every call, so callers may modify the result.
func getStandardBucketBoundaries() []float64 {
	boundaries := []float64{
		0.1, 0.5,
		1, 5,
		10, 50,
		100, 500,
		1000, 5000,
		10000,
	}
	return boundaries
}

153
relay/server/handshake.go Normal file
View File

@ -0,0 +1,153 @@
package server
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
)
// preparedMsg holds the already-marshalled success responses of the handshake,
// so they do not have to be marshalled again for every connecting peer.
type preparedMsg struct {
	// responseHelloMsg answers the deprecated hello handshake message;
	// responseAuthMsg answers the auth handshake message.
	responseHelloMsg, responseAuthMsg []byte
}
// newPreparedMsg marshals both handshake success responses for the given
// instance URL and returns them bundled in a preparedMsg.
// It returns an error if either response cannot be marshalled.
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
	var (
		prepared preparedMsg
		err      error
	)

	if prepared.responseHelloMsg, err = marshalResponseHelloMsg(instanceURL); err != nil {
		return nil, err
	}

	if prepared.responseAuthMsg, err = messages.MarshalAuthResponse(instanceURL); err != nil {
		return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
	}

	return &prepared, nil
}
// marshalResponseHelloMsg builds the response to the deprecated hello message:
// the instance URL is wrapped in an address.Address, marshalled, and then
// embedded into a hello response.
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
	encodedAddr, err := (&address.Address{URL: instanceURL}).Marshal()
	if err != nil {
		return nil, fmt.Errorf("failed to marshal response address: %w", err)
	}

	//nolint:staticcheck
	helloResponse, err := messages.MarshalHelloResponse(encodedAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal hello response: %w", err)
	}

	return helloResponse, nil
}
// handshake carries the state of a single peer handshake: the incoming message
// is processed by handshakeReceive and the matching success response is sent
// later by handshakeResponse.
type handshake struct {
// conn is the peer connection the handshake message is read from and the
// response is written to.
conn net.Conn
// validator checks the authentication data sent by the peer.
validator auth.Validator
// preparedMsg holds the pre-marshalled success responses.
preparedMsg *preparedMsg
// handshakeMethodAuth is set to true when the peer used the auth message
// type; it selects which prepared response is sent back.
handshakeMethodAuth bool
// peerID is the string form of the peer ID, set after a successful
// handshakeReceive and used in error messages.
peerID string
}
// handshakeReceive reads the first message from the connection, validates the
// protocol version, and hands the payload to the handler for its message type
// (deprecated hello or auth).
//
// On success it stores the string form of the peer ID in h.peerID and returns
// the raw peer ID bytes. No response is written here; that is done by
// handshakeResponse.
func (h *handshake) handshakeReceive() ([]byte, error) {
	raw := make([]byte, messages.MaxHandshakeSize)
	size, err := h.conn.Read(raw)
	if err != nil {
		return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
	}
	msg := raw[:size]

	if _, err := messages.ValidateVersion(msg); err != nil {
		return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
	}

	msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
	if err != nil {
		return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
	}

	// Pick the handler first; the payload is sliced and processed afterwards.
	var handle func([]byte) ([]byte, string, error)
	switch msgType {
	//nolint:staticcheck
	case messages.MsgTypeHello:
		handle = h.handleHelloMsg
	case messages.MsgTypeAuth:
		// Remember the method so handshakeResponse sends the auth response.
		h.handshakeMethodAuth = true
		handle = h.handleAuthMsg
	default:
		return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
	}

	rawID, id, err := handle(msg[messages.SizeOfProtoHeader:])
	if err != nil {
		return nil, err
	}

	h.peerID = id
	return rawID, nil
}
// handshakeResponse writes the prepared success response to the peer. The auth
// response is sent when the peer used the auth message type, otherwise the
// hello response is sent.
func (h *handshake) handshakeResponse() error {
	reply := h.preparedMsg.responseHelloMsg
	if h.handshakeMethodAuth {
		reply = h.preparedMsg.responseAuthMsg
	}

	_, err := h.conn.Write(reply)
	if err != nil {
		return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
	}
	return nil
}
// handleHelloMsg processes the deprecated hello message: it extracts the peer
// ID and the embedded auth data, logs a deprecation warning, and validates the
// auth message's additional data.
//
// It returns the raw peer ID together with its string form.
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
	//nolint:staticcheck
	hashedID, authBytes, err := messages.UnmarshalHelloMsg(buf)
	if err != nil {
		return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
	}

	id := messages.HashIDToString(hashedID)
	log.Warnf("peer %s (%s) is using deprecated initial message type", id, h.conn.RemoteAddr())

	parsedAuth, err := authmsg.UnmarshalMsg(authBytes)
	if err != nil {
		return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
	}

	//nolint:staticcheck
	err = h.validator.ValidateHelloMsgType(parsedAuth.AdditionalData)
	if err != nil {
		return nil, "", fmt.Errorf("validate %s (%s): %w", id, h.conn.RemoteAddr(), err)
	}

	return hashedID, id, nil
}
// handleAuthMsg processes the auth handshake message: it extracts the peer ID
// and the auth payload and validates the payload.
//
// It returns the raw peer ID together with its string form.
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
	rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
	if err != nil {
		// This is the auth path; the message previously said "hello".
		return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
	}

	peerID := messages.HashIDToString(rawPeerID)

	if err := h.validator.Validate(authPayload); err != nil {
		return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
	}

	return rawPeerID, peerID, nil
}

View File

@ -7,16 +7,13 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
@ -28,6 +25,7 @@ type Relay struct {
store *Store store *Store
instanceURL string instanceURL string
preparedMsg *preparedMsg
closed bool closed bool
closeMu sync.RWMutex closeMu sync.RWMutex
@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return nil, fmt.Errorf("get instance URL: %v", err) return nil, fmt.Errorf("get instance URL: %v", err)
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
}
return r, nil return r, nil
} }
@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now()
r.closeMu.RLock() r.closeMu.RLock()
defer r.closeMu.RUnlock() defer r.closeMu.RUnlock()
if r.closed { if r.closed {
return return
} }
peerID, err := r.handshake(conn) h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
peerID, err := h.handshakeReceive()
if err != nil { if err != nil {
log.Errorf("failed to handshake: %s", err) log.Errorf("failed to handshake: %s", err)
cErr := conn.Close() if cErr := conn.Close(); cErr != nil {
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
} }
return return
@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, peerID, conn, r.store) peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer) r.store.AddPeer(peer)
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()
if err := h.handshakeResponse(); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
} }
// Shutdown closes the relay server // Shutdown closes the relay server
@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
func (r *Relay) InstanceURL() string { func (r *Relay) InstanceURL() string {
return r.instanceURL return r.instanceURL
} }
// handshake reads the first message from the connection, validates the
// protocol version, hands the payload to the handler for its message type, and
// writes the handler's response back to the peer.
//
// It returns the raw peer ID on success.
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
	raw := make([]byte, messages.MaxHandshakeSize)
	size, err := conn.Read(raw)
	if err != nil {
		return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
	}
	msg := raw[:size]

	if _, err := messages.ValidateVersion(msg); err != nil {
		return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
	}

	msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
	if err != nil {
		return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
	}

	// Pick the handler first; the payload is sliced and processed afterwards.
	var handle func([]byte, net.Addr) ([]byte, []byte, error)
	switch msgType {
	//nolint:staticcheck
	case messages.MsgTypeHello:
		handle = r.handleHelloMsg
	case messages.MsgTypeAuth:
		handle = r.handleAuthMsg
	default:
		return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
	}

	peerID, reply, err := handle(msg[messages.SizeOfProtoHeader:], conn.RemoteAddr())
	if err != nil {
		return nil, err
	}

	if _, err := conn.Write(reply); err != nil {
		return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
	}

	return peerID, nil
}
// handleHelloMsg processes the deprecated hello message: it extracts the peer
// ID and the embedded auth data, logs a deprecation warning, validates the
// auth message's additional data, and marshals the hello response carrying the
// relay's instance URL.
//
// It returns the raw peer ID and the marshalled response message.
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
	//nolint:staticcheck
	rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
	if err != nil {
		return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
	}

	peerID := messages.HashIDToString(rawPeerID)
	log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)

	authMsg, err := authmsg.UnmarshalMsg(authData)
	if err != nil {
		return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
	}

	//nolint:staticcheck
	if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
		return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
	}

	addr := &address.Address{URL: r.instanceURL}
	addrData, err := addr.Marshal()
	if err != nil {
		// Typo fixed: the message previously read "addressc".
		return nil, nil, fmt.Errorf("marshal address to %s (%s): %w", peerID, remoteAddr, err)
	}

	//nolint:staticcheck
	responseMsg, err := messages.MarshalHelloResponse(addrData)
	if err != nil {
		return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
	}

	return rawPeerID, responseMsg, nil
}
// handleAuthMsg processes the auth handshake message: it extracts the peer ID
// and the auth payload, validates the payload, and marshals the auth response
// carrying the relay's instance URL.
//
// It returns the raw peer ID and the marshalled response message.
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
	rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
	if err != nil {
		// This is the auth path; the message previously said "hello".
		return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
	}

	peerID := messages.HashIDToString(rawPeerID)

	if err := r.validator.Validate(authPayload); err != nil {
		return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
	}

	responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
	if err != nil {
		// Likewise corrected from "marshal hello response".
		return nil, nil, fmt.Errorf("marshal auth response to %s (%s): %w", peerID, addr, err)
	}

	return rawPeerID, responseMsg, nil
}