mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-24 11:41:35 +02:00
[relay] Code cleaning (#3074)
- Keep message byte processing in message.go file - Add new unit tests
This commit is contained in:
parent
b34887a920
commit
6a6b527f24
@ -306,7 +306,7 @@ func (c *Client) handShake() error {
|
|||||||
return fmt.Errorf("validate version: %w", err)
|
return fmt.Errorf("validate version: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to determine message type: %s", err)
|
c.log.Errorf("failed to determine message type: %s", err)
|
||||||
return err
|
return err
|
||||||
@ -317,7 +317,7 @@ func (c *Client) handShake() error {
|
|||||||
return fmt.Errorf("unexpected message type")
|
return fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -348,24 +348,27 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
c.bufPool.Put(bufPtr)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := messages.ValidateVersion(buf[:n])
|
buf = buf[:n]
|
||||||
|
|
||||||
|
_, err := messages.ValidateVersion(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to validate protocol version: %s", err)
|
c.log.Errorf("failed to validate protocol version: %s", err)
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
msgType, err := messages.DetermineServerMessageType(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to determine message type: %s", err)
|
c.log.Errorf("failed to determine message type: %s", err)
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,20 +23,26 @@ const (
|
|||||||
MsgTypeAuth = 6
|
MsgTypeAuth = 6
|
||||||
MsgTypeAuthResponse = 7
|
MsgTypeAuthResponse = 7
|
||||||
|
|
||||||
SizeOfVersionByte = 1
|
// base size of the message
|
||||||
SizeOfMsgType = 1
|
sizeOfVersionByte = 1
|
||||||
|
sizeOfMsgType = 1
|
||||||
|
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
|
||||||
|
|
||||||
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
|
// auth message
|
||||||
|
sizeOfMagicByte = 4
|
||||||
sizeOfMagicByte = 4
|
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||||
|
offsetMagicByte = sizeOfProtoHeader
|
||||||
headerSizeTransport = IDSize
|
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||||
|
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||||
|
|
||||||
|
// hello message
|
||||||
headerSizeHello = sizeOfMagicByte + IDSize
|
headerSizeHello = sizeOfMagicByte + IDSize
|
||||||
headerSizeHelloResp = 0
|
headerSizeHelloResp = 0
|
||||||
|
|
||||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
// transport
|
||||||
headerSizeAuthResp = 0
|
headerSizeTransport = IDSize
|
||||||
|
offsetTransportID = sizeOfProtoHeader
|
||||||
|
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -73,7 +79,7 @@ func (m MsgType) String() string {
|
|||||||
|
|
||||||
// ValidateVersion checks if the given version is supported by the protocol
|
// ValidateVersion checks if the given version is supported by the protocol
|
||||||
func ValidateVersion(msg []byte) (int, error) {
|
func ValidateVersion(msg []byte) (int, error) {
|
||||||
if len(msg) < SizeOfVersionByte {
|
if len(msg) < sizeOfProtoHeader {
|
||||||
return 0, ErrInvalidMessageLength
|
return 0, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
version := int(msg[0])
|
version := int(msg[0])
|
||||||
@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {
|
|||||||
|
|
||||||
// DetermineClientMessageType determines the message type from the first the message
|
// DetermineClientMessageType determines the message type from the first the message
|
||||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||||
if len(msg) < SizeOfMsgType {
|
if len(msg) < sizeOfProtoHeader {
|
||||||
return 0, ErrInvalidMessageLength
|
return 0, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := MsgType(msg[0])
|
msgType := MsgType(msg[1])
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case
|
case
|
||||||
MsgTypeHello,
|
MsgTypeHello,
|
||||||
@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
|||||||
|
|
||||||
// DetermineServerMessageType determines the message type from the first the message
|
// DetermineServerMessageType determines the message type from the first the message
|
||||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||||
if len(msg) < SizeOfMsgType {
|
if len(msg) < sizeOfProtoHeader {
|
||||||
return 0, ErrInvalidMessageLength
|
return 0, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := MsgType(msg[0])
|
msgType := MsgType(msg[1])
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case
|
case
|
||||||
MsgTypeHelloResponse,
|
MsgTypeHelloResponse,
|
||||||
@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
|
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeHello)
|
msg[1] = byte(MsgTypeHello)
|
||||||
|
|
||||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||||
|
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID...)
|
||||||
msg = append(msg, additions...)
|
msg = append(msg, additions...)
|
||||||
@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||||
// authenticate the client with the server.
|
// authenticate the client with the server.
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||||
if len(msg) < headerSizeHello {
|
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
|
||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use MarshalAuthResponse instead.
|
// Deprecated: Use MarshalAuthResponse instead.
|
||||||
@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
|||||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||||
// servers.
|
// servers.
|
||||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeHelloResponse)
|
msg[1] = byte(MsgTypeHelloResponse)
|
||||||
@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
|||||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||||
if len(msg) < headerSizeHelloResp {
|
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return msg, nil
|
return msg, nil
|
||||||
@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeAuth)
|
msg[1] = byte(MsgTypeAuth)
|
||||||
|
|
||||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||||
|
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID...)
|
||||||
msg = append(msg, authPayload...)
|
msg = append(msg, authPayload...)
|
||||||
@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||||
if len(msg) < headerSizeAuth {
|
if len(msg) < headerTotalSizeAuth {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
|
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalAuthResponse creates a response message to the auth.
|
// MarshalAuthResponse creates a response message to the auth.
|
||||||
@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
|||||||
// servers.
|
// servers.
|
||||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||||
ab := []byte(address)
|
ab := []byte(address)
|
||||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeAuthResponse)
|
msg[1] = byte(MsgTypeAuthResponse)
|
||||||
@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||||
if len(msg) < headerSizeAuthResp+1 {
|
if len(msg) < sizeOfProtoHeader+1 {
|
||||||
return "", ErrInvalidMessageLength
|
return "", ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return string(msg), nil
|
return string(msg[sizeOfProtoHeader:]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalCloseMsg creates a close message.
|
// MarshalCloseMsg creates a close message.
|
||||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||||
func MarshalCloseMsg() []byte {
|
func MarshalCloseMsg() []byte {
|
||||||
msg := make([]byte, SizeOfProtoHeader)
|
return []byte{
|
||||||
|
byte(CurrentProtocolVersion),
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
byte(MsgTypeClose),
|
||||||
msg[1] = byte(MsgTypeClose)
|
}
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalTransportMsg creates a transport message.
|
// MarshalTransportMsg creates a transport message.
|
||||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||||
// destination peer hashed ID.
|
// destination peer hashed ID.
|
||||||
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if len(peerID) != IDSize {
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
|
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeTransport)
|
msg[1] = byte(MsgTypeTransport)
|
||||||
|
copy(msg[sizeOfProtoHeader:], peerID)
|
||||||
copy(msg[SizeOfProtoHeader:], peerID)
|
|
||||||
|
|
||||||
msg = append(msg, payload...)
|
msg = append(msg, payload...)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||||
if len(buf) < headerSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
|
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||||
if len(buf) < headerSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return buf[:headerSizeTransport], nil
|
return buf[offsetTransportID:headerTotalSizeTransport], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTransportMsg updates the peerID in the transport message.
|
// UpdateTransportMsg updates the peerID in the transport message.
|
||||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||||
// need to allocate a new byte slice.
|
// need to allocate a new byte slice.
|
||||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||||
if len(msg) < len(peerID) {
|
if len(msg) < offsetTransportID+len(peerID) {
|
||||||
return ErrInvalidMessageLength
|
return ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
copy(msg, peerID)
|
copy(msg[offsetTransportID:], peerID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,12 +6,21 @@ import (
|
|||||||
|
|
||||||
func TestMarshalHelloMsg(t *testing.T) {
|
func TestMarshalHelloMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
bHello, err := MarshalHelloMsg(peerID, nil)
|
msg, err := MarshalHelloMsg(peerID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
|
msgType, err := DetermineClientMessageType(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != MsgTypeHello {
|
||||||
|
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) {
|
|||||||
|
|
||||||
func TestMarshalAuthMsg(t *testing.T) {
|
func TestMarshalAuthMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
bHello, err := MarshalAuthMsg(peerID, []byte{})
|
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
|
msgType, err := DetermineClientMessageType(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != MsgTypeAuth {
|
||||||
|
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalAuthResponse(t *testing.T) {
|
||||||
|
address := "myaddress"
|
||||||
|
msg, err := MarshalAuthResponse(address)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := DetermineServerMessageType(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != MsgTypeAuthResponse {
|
||||||
|
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
respAddr, err := UnmarshalAuthResponse(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
if respAddr != address {
|
||||||
|
t.Errorf("expected %s, got %s", address, respAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshalTransportMsg(t *testing.T) {
|
func TestMarshalTransportMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
payload := []byte("payload")
|
payload := []byte("payload")
|
||||||
@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
|
msgType, err := DetermineClientMessageType(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != MsgTypeTransport {
|
||||||
|
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
uPeerID, err := UnmarshalTransportID(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(uPeerID) != string(peerID) {
|
||||||
|
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalHealthcheck(t *testing.T) {
|
||||||
|
msg := MarshalHealthcheck()
|
||||||
|
|
||||||
|
_, err := ValidateVersion(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := DetermineServerMessageType(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != MsgTypeHealthCheck {
|
||||||
|
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = messages.ValidateVersion(buf[:n])
|
buf = buf[:n]
|
||||||
|
|
||||||
|
_, err = messages.ValidateVersion(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
msgType, err := messages.DetermineClientMessageType(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
switch msgType {
|
switch msgType {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
case messages.MsgTypeHello:
|
case messages.MsgTypeHello:
|
||||||
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
bytePeerID, peerID, err = h.handleHelloMsg(buf)
|
||||||
case messages.MsgTypeAuth:
|
case messages.MsgTypeAuth:
|
||||||
h.handshakeMethodAuth = true
|
h.handshakeMethodAuth = true
|
||||||
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
bytePeerID, peerID, err = h.handleAuthMsg(buf)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func (p *Peer) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
|
msgType, err := messages.DetermineClientMessageType(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to determine message type: %s", err)
|
p.log.Errorf("failed to determine message type: %s", err)
|
||||||
return
|
return
|
||||||
@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) handleTransportMsg(msg []byte) {
|
func (p *Peer) handleTransportMsg(msg []byte) {
|
||||||
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
|
peerID, err := messages.UnmarshalTransportID(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||||
return
|
return
|
||||||
@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
|
err = messages.UpdateTransportMsg(msg, p.idB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to update transport message: %s", err)
|
p.log.Errorf("failed to update transport message: %s", err)
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user