mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 02:21:47 +02:00
[relay] Code cleaning (#3074)
- Keep message byte processing in message.go file - Add new unit tests
This commit is contained in:
@ -23,20 +23,26 @@ const (
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
SizeOfVersionByte = 1
|
||||
SizeOfMsgType = 1
|
||||
// base size of the message
|
||||
sizeOfVersionByte = 1
|
||||
sizeOfMsgType = 1
|
||||
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
|
||||
|
||||
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
|
||||
|
||||
sizeOfMagicByte = 4
|
||||
|
||||
headerSizeTransport = IDSize
|
||||
// auth message
|
||||
sizeOfMagicByte = 4
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
offsetMagicByte = sizeOfProtoHeader
|
||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||
|
||||
// hello message
|
||||
headerSizeHello = sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
headerSizeAuthResp = 0
|
||||
// transport
|
||||
headerSizeTransport = IDSize
|
||||
offsetTransportID = sizeOfProtoHeader
|
||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||
)
|
||||
|
||||
var (
|
||||
@ -73,7 +79,7 @@ func (m MsgType) String() string {
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < SizeOfVersionByte {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, additions...)
|
||||
@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeHello {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < headerSizeHelloResp {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return msg, nil
|
||||
@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, authPayload...)
|
||||
@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
|
||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeAuth {
|
||||
if len(msg) < headerTotalSizeAuth {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
|
||||
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
|
||||
}
|
||||
|
||||
// MarshalAuthResponse creates a response message to the auth.
|
||||
@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < headerSizeAuthResp+1 {
|
||||
if len(msg) < sizeOfProtoHeader+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg), nil
|
||||
return string(msg[sizeOfProtoHeader:]), nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, SizeOfProtoHeader)
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeClose)
|
||||
|
||||
return msg
|
||||
return []byte{
|
||||
byte(CurrentProtocolVersion),
|
||||
byte(MsgTypeClose),
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalTransportMsg creates a transport message.
|
||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||
// destination peer hashed ID.
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
|
||||
|
||||
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:], peerID)
|
||||
|
||||
copy(msg[sizeOfProtoHeader:], peerID)
|
||||
msg = append(msg, payload...)
|
||||
|
||||
return msg, nil
|
||||
@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
|
||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[:headerSizeTransport], nil
|
||||
return buf[offsetTransportID:headerTotalSizeTransport], nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||
// need to allocate a new byte slice.
|
||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||
if len(msg) < len(peerID) {
|
||||
if len(msg) < offsetTransportID+len(peerID) {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg, peerID)
|
||||
copy(msg[offsetTransportID:], peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user