diff --git a/relay/client/client.go b/relay/client/client.go index db5252f50..bccd85c93 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -306,7 +306,7 @@ func (c *Client) handShake() error { return fmt.Errorf("validate version: %w", err) } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { c.log.Errorf("failed to determine message type: %s", err) return err @@ -317,7 +317,7 @@ func (c *Client) handShake() error { return fmt.Errorf("unexpected message type") } - addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) + addr, err := messages.UnmarshalAuthResponse(buf[:n]) if err != nil { return err } @@ -348,24 +348,27 @@ func (c *Client) readLoop(relayConn net.Conn) { c.log.Debugf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() + c.bufPool.Put(bufPtr) break } - _, err := messages.ValidateVersion(buf[:n]) + buf = buf[:n] + + _, err := messages.ValidateVersion(buf) if err != nil { c.log.Errorf("failed to validate protocol version: %s", err) c.bufPool.Put(bufPtr) continue } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf) if err != nil { c.log.Errorf("failed to determine message type: %s", err) c.bufPool.Put(bufPtr) continue } - if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) { + if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) { break } } diff --git a/relay/messages/message.go b/relay/messages/message.go index 39ca0aa90..7794c57bc 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -23,20 +23,26 @@ const ( MsgTypeAuth = 6 MsgTypeAuthResponse = 7 - SizeOfVersionByte = 1 - SizeOfMsgType = 1 + // base size of the message + sizeOfVersionByte = 1 + sizeOfMsgType = 1 + sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType - SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType - - sizeOfMagicByte = 4 - - headerSizeTransport = IDSize + // auth message + sizeOfMagicByte = 4 + headerSizeAuth = sizeOfMagicByte + IDSize + offsetMagicByte = sizeOfProtoHeader + offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte + headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth + // hello message headerSizeHello = sizeOfMagicByte + IDSize headerSizeHelloResp = 0 - headerSizeAuth = sizeOfMagicByte + IDSize - headerSizeAuthResp = 0 + // transport + headerSizeTransport = IDSize + offsetTransportID = sizeOfProtoHeader + headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport ) var ( @@ -73,7 +79,7 @@ func (m MsgType) String() string { // ValidateVersion checks if the given version is supported by the protocol func ValidateVersion(msg []byte) (int, error) { - if len(msg) < SizeOfVersionByte { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } version := int(msg[0]) @@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) { // DetermineClientMessageType determines the message type from the first the message func DetermineClientMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHello, @@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { // DetermineServerMessageType determines the message type from the first the message func DetermineServerMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHelloResponse, @@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions)) + msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHello) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) msg = append(msg, peerID...) msg = append(msg, additions...) @@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { - if len(msg) < headerSizeHello { + if len(msg) < sizeOfProtoHeader+headerSizeHello { return nil, nil, ErrInvalidMessageLength } - if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil + return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil } // Deprecated: Use MarshalAuthResponse instead. @@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // servers. func MarshalHelloResponse(additionalData []byte) ([]byte, error) { - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) + msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHelloResponse) @@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) { // Deprecated: Use UnmarshalAuthResponse instead. // UnmarshalHelloResponse extracts the additional data from the hello response message. func UnmarshalHelloResponse(msg []byte) ([]byte, error) { - if len(msg) < headerSizeHelloResp { + if len(msg) < sizeOfProtoHeader+headerSizeHelloResp { return nil, ErrInvalidMessageLength } return msg, nil @@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) + msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuth) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfProtoHeader:], magicHeader) msg = append(msg, peerID...) msg = append(msg, authPayload...) @@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { // UnmarshalAuthMsg extracts peerID and the auth payload from the message func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { - if len(msg) < headerSizeAuth { + if len(msg) < headerTotalSizeAuth { return nil, nil, ErrInvalidMessageLength } - if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil + return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil } // MarshalAuthResponse creates a response message to the auth. @@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { // servers. func MarshalAuthResponse(address string) ([]byte, error) { ab := []byte(address) - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) + msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuthResponse) @@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) { // UnmarshalAuthResponse it is a confirmation message to auth success func UnmarshalAuthResponse(msg []byte) (string, error) { - if len(msg) < headerSizeAuthResp+1 { + if len(msg) < sizeOfProtoHeader+1 { return "", ErrInvalidMessageLength } - return string(msg), nil + return string(msg[sizeOfProtoHeader:]), nil } // MarshalCloseMsg creates a close message. // The close message is used to close the connection gracefully between the client and the server. The server and the // client can send this message. After receiving this message, the server or client will close the connection. func MarshalCloseMsg() []byte { - msg := make([]byte, SizeOfProtoHeader) - - msg[0] = byte(CurrentProtocolVersion) - msg[1] = byte(MsgTypeClose) - - return msg + return []byte{ + byte(CurrentProtocolVersion), + byte(MsgTypeClose), + } } // MarshalTransportMsg creates a transport message. // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // destination peer hashed ID. -func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { +func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { if len(peerID) != IDSize { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload)) - + msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeTransport) - - copy(msg[SizeOfProtoHeader:], peerID) - + copy(msg[sizeOfProtoHeader:], peerID) msg = append(msg, payload...) return msg, nil @@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { // UnmarshalTransportMsg extracts the peerID and the payload from the transport message. func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { - if len(buf) < headerSizeTransport { + if len(buf) < headerTotalSizeTransport { return nil, nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], buf[headerSizeTransport:], nil + return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil } // UnmarshalTransportID extracts the peerID from the transport message. func UnmarshalTransportID(buf []byte) ([]byte, error) { - if len(buf) < headerSizeTransport { + if len(buf) < headerTotalSizeTransport { return nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], nil + return buf[offsetTransportID:headerTotalSizeTransport], nil } // UpdateTransportMsg updates the peerID in the transport message. // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // need to allocate a new byte slice. func UpdateTransportMsg(msg []byte, peerID []byte) error { - if len(msg) < len(peerID) { + if len(msg) < offsetTransportID+len(peerID) { return ErrInvalidMessageLength } - copy(msg, peerID) + copy(msg[offsetTransportID:], peerID) return nil } diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index 6e917da71..19bede07b 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -6,12 +6,21 @@ import ( func TestMarshalHelloMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") - bHello, err := MarshalHelloMsg(peerID, nil) + msg, err := MarshalHelloMsg(peerID, nil) if err != nil { t.Fatalf("error: %v", err) } - receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeHello { + t.Errorf("expected %d, got %d", MsgTypeHello, msgType) + } + + receivedPeerID, _, err := UnmarshalHelloMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") - bHello, err := MarshalAuthMsg(peerID, []byte{}) + msg, err := MarshalAuthMsg(peerID, []byte{}) if err != nil { t.Fatalf("error: %v", err) } - receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeAuth { + t.Errorf("expected %d, got %d", MsgTypeAuth, msgType) + } + + receivedPeerID, _, err := UnmarshalAuthMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) { } } +func TestMarshalAuthResponse(t *testing.T) { + address := "myaddress" + msg, err := MarshalAuthResponse(address) + if err != nil { + t.Fatalf("error: %v", err) + } + + msgType, err := DetermineServerMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeAuthResponse { + t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType) + } + + respAddr, err := UnmarshalAuthResponse(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + if respAddr != address { + t.Errorf("expected %s, got %s", address, respAddr) + } +} + func TestMarshalTransportMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") payload := []byte("payload") @@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("error: %v", err) } - id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeTransport { + t.Errorf("expected %d, got %d", MsgTypeTransport, msgType) + } + + uPeerID, err := UnmarshalTransportID(msg) + if err != nil { + t.Fatalf("failed to unmarshal transport id: %v", err) + } + + if string(uPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, uPeerID) + } + + id, respPayload, err := UnmarshalTransportMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) { t.Errorf("expected %s, got %s", payload, respPayload) } } + +func TestMarshalHealthcheck(t *testing.T) { + msg := MarshalHealthcheck() + + _, err := ValidateVersion(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + msgType, err := DetermineServerMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeHealthCheck { + t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType) + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 0257300f8..babd6f955 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) } - _, err = messages.ValidateVersion(buf[:n]) + buf = buf[:n] + + _, err = messages.ValidateVersion(buf) if err != nil { return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) } - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineClientMessageType(buf) if err != nil { return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) } @@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) { switch msgType { //nolint:staticcheck case messages.MsgTypeHello: - bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + bytePeerID, peerID, err = h.handleHelloMsg(buf) case messages.MsgTypeAuth: h.handshakeMethodAuth = true - bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + bytePeerID, peerID, err = h.handleAuthMsg(buf) default: return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } diff --git a/relay/server/peer.go b/relay/server/peer.go index f65fb786a..aa9790f63 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -84,7 +84,7 @@ func (p *Peer) Work() { return } - msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:]) + msgType, err := messages.DetermineClientMessageType(msg) if err != nil { p.log.Errorf("failed to determine message type: %s", err) return @@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send } func (p *Peer) handleTransportMsg(msg []byte) { - peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:]) + peerID, err := messages.UnmarshalTransportID(msg) if err != nil { p.log.Errorf("failed to unmarshal transport message: %s", err) return @@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB) + err = messages.UpdateTransportMsg(msg, p.idB) if err != nil { p.log.Errorf("failed to update transport message: %s", err) return