Eliminate gob usage from Relay protocol

This commit is contained in:
Zoltán Papp 2024-09-10 16:08:51 +02:00
parent f43a0a0177
commit a701148658
5 changed files with 136 additions and 66 deletions

View File

@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
@ -240,31 +238,21 @@ func (c *Client) connect() error {
}
func (c *Client) handShake() error {
authMsg := &auth2.Msg{
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
return fmt.Errorf("marshal auth message: %w", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
log.Errorf("failed to read auth response: %s", err)
return err
}
@ -279,23 +267,18 @@ func (c *Client) handShake() error {
return err
}
if msgType != messages.MsgTypeHelloResponse {
if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return err
}
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL}
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
}

View File

@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address
import (
@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
}
return buf.Bytes(), nil
}
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth
import (
@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte
}
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg

View File

@ -7,12 +7,16 @@ import (
)
const (
MsgTypeUnknown MsgType = 0
MsgTypeHello MsgType = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1
SizeOfMsgType = 1
@ -47,6 +51,10 @@ func (m MsgType) String() string {
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
@ -58,10 +66,6 @@ func (m MsgType) String() string {
}
}
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
@ -84,6 +88,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@ -103,6 +108,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@ -112,6 +118,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
}
}
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@ -135,6 +142,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@ -148,6 +156,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@ -163,6 +172,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
@ -171,6 +181,65 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil
}
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
msg = append(msg, ab...)
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeHelloResp+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.

View File

@ -21,9 +21,10 @@ import (
// Relay represents the relay server
type Relay struct {
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
metrics *metrics.Metrics
metricsCancel context.CancelFunc
validator auth.Validator
validatorDummy auth.Validator // todo: this is just a dummy variable. Replace it with the proper validator
store *Store
instanceURL string
@ -168,14 +169,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
if msgType != messages.MsgTypeHello {
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
var (
responseMsg []byte
peerID []byte
)
switch msgType {
case messages.MsgTypeHello:
responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
}
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, error) {
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
log.Warnf("peer is using depracated initial message type: %s (%s)", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
@ -183,24 +206,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
msg, err := messages.MarshalHelloResponse(addrData)
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
_, err = conn.Write(msg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
return responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, error) {
peerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
// todo use the proper validator
if err := r.validatorDummy.Validate(sha256.New, authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return responseMsg, nil
}