Eliminate gob usage from Relay protocol

This commit is contained in:
Zoltán Papp 2024-09-10 16:08:51 +02:00
parent f43a0a0177
commit a701148658
5 changed files with 136 additions and 66 deletions

View File

@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
) )
const ( const (
@ -240,31 +238,21 @@ func (c *Client) connect() error {
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
authMsg := &auth2.Msg{ msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("marshal auth message: %w", err) log.Errorf("failed to marshal auth message: %s", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err return err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to send hello message: %s", err) log.Errorf("failed to send auth message: %s", err)
return err return err
} }
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(buf)
if err != nil { if err != nil {
log.Errorf("failed to read hello response: %s", err) log.Errorf("failed to read auth response: %s", err)
return err return err
} }
@ -279,23 +267,18 @@ func (c *Client) handShake() error {
return err return err
} }
if msgType != messages.MsgTypeHelloResponse { if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType) log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n]) addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil { if err != nil {
return err return err
} }
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock() c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL} c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
return nil return nil
} }

View File

@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address package address
import ( import (
@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth package auth
import ( import (
@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte AdditionalData []byte
} }
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) { func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg var msg *Msg

View File

@ -7,12 +7,16 @@ import (
) )
const ( const (
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
MsgTypeHello MsgType = 1 // Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4 MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1 SizeOfVersionByte = 1
SizeOfMsgType = 1 SizeOfMsgType = 1
@ -47,6 +51,10 @@ func (m MsgType) String() string {
return "hello" return "hello"
case MsgTypeHelloResponse: case MsgTypeHelloResponse:
return "hello response" return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport: case MsgTypeTransport:
return "transport" return "transport"
case MsgTypeClose: case MsgTypeClose:
@ -58,10 +66,6 @@ func (m MsgType) String() string {
} }
} }
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol // ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) { func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte { if len(msg) < SizeOfVersionByte {
@ -84,6 +88,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHello, MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@ -103,6 +108,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHelloResponse, MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@ -112,6 +118,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
} }
} }
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This // The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@ -135,6 +142,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@ -148,6 +156,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message. // MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's // In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@ -163,6 +172,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message. // UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) { func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp { if len(msg) < headerSizeHelloResp {
@ -171,6 +181,65 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
msg = append(msg, ab...)
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeHelloResp+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message. // MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the // The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection. // client can send this message. After receiving this message, the server or client will close the connection.

View File

@ -21,9 +21,10 @@ import (
// Relay represents the relay server // Relay represents the relay server
type Relay struct { type Relay struct {
metrics *metrics.Metrics metrics *metrics.Metrics
metricsCancel context.CancelFunc metricsCancel context.CancelFunc
validator auth.Validator validator auth.Validator
validatorDummy auth.Validator // todo: this is just a dummy variable. Replace it with the proper validator
store *Store store *Store
instanceURL string instanceURL string
@ -168,14 +169,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
} }
if msgType != messages.MsgTypeHello { var (
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr()) responseMsg []byte
peerID []byte
)
switch msgType {
case messages.MsgTypeHello:
responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
} }
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n]) _, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, error) {
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
log.Warnf("peer is using depracated initial message type: %s (%s)", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData) authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil { if err != nil {
@ -183,24 +206,36 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
} }
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil { if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
} }
addr := &address.Address{URL: r.instanceURL} addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal() addrData, err := addr.Marshal()
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err) return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
} }
msg, err := messages.MarshalHelloResponse(addrData) responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err) return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
} }
return responseMsg, nil
_, err = conn.Write(msg) }
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, error) {
} peerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return peerID, nil return nil, fmt.Errorf("unmarshal hello message: %w", err)
}
// todo use the proper validator
if err := r.validatorDummy.Validate(sha256.New, authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return responseMsg, nil
} }