mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-08 23:05:28 +02:00
[relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
This commit is contained in:
@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@ -14,7 +13,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHello {
|
||||
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
|
||||
var (
|
||||
responseMsg []byte
|
||||
peerID []byte
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
case messages.MsgTypeAuth:
|
||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
_, err = conn.Write(responseMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := r.validator.Validate(authPayload); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user