From b87173f47d0409b686986c79a58e4edaa51717f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Tue, 10 Sep 2024 17:06:22 +0200 Subject: [PATCH] Move header offset calculation to private values --- relay/client/client.go | 8 ++--- relay/messages/message.go | 54 +++++++++++++++++----------------- relay/messages/message_test.go | 36 +++++++++++++++++++++-- relay/server/peer.go | 12 ++++---- relay/server/relay.go | 6 ++-- 5 files changed, 73 insertions(+), 43 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 6560c81e1..c3d5561f2 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -261,7 +261,7 @@ func (c *Client) handShake() error { return fmt.Errorf("validate version: %w", err) } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { log.Errorf("failed to determine message type: %s", err) return err @@ -272,7 +272,7 @@ func (c *Client) handShake() error { return fmt.Errorf("unexpected message type") } - addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) + addr, err := messages.UnmarshalAuthResponse(buf[:n]) if err != nil { return err } @@ -312,14 +312,14 @@ func (c *Client) readLoop(relayConn net.Conn) { continue } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { c.log.Errorf("failed to determine message type: %s", err) c.bufPool.Put(bufPtr) continue } - if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) { + if !c.handleMsg(msgType, buf[:n], bufPtr, hc, internallyStoppedFlag) { break } } diff --git a/relay/messages/message.go b/relay/messages/message.go index 17b307b43..fe3d5c043 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -23,20 +23,20 @@ const ( MsgTypeAuth = 6 MsgTypeAuthResponse = 7 - SizeOfVersionByte = 1 - SizeOfMsgType = 1 + sizeOfVersionByte = 1 + sizeOfMsgType = 1 - SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType + sizeOfCommonHeader = sizeOfVersionByte + sizeOfMsgType sizeOfMagicByte = 4 - headerSizeTransport = IDSize + headerSizeTransport = sizeOfCommonHeader + IDSize - headerSizeHello = sizeOfMagicByte + IDSize - headerSizeHelloResp = 0 + headerSizeHello = sizeOfCommonHeader + sizeOfMagicByte + IDSize + headerSizeHelloResp = sizeOfCommonHeader + sizeOfCommonHeader - headerSizeAuth = sizeOfMagicByte + IDSize - headerSizeAuthResp = 0 + headerSizeAuth = sizeOfCommonHeader + sizeOfMagicByte + IDSize + headerSizeAuthResp = sizeOfCommonHeader ) var ( @@ -73,7 +73,7 @@ func (m MsgType) String() string { // ValidateVersion checks if the given version is supported by the protocol func ValidateVersion(msg []byte) (int, error) { - if len(msg) < SizeOfVersionByte { + if len(msg) < sizeOfCommonHeader { return 0, ErrInvalidMessageLength } version := int(msg[0]) @@ -85,11 +85,11 @@ func ValidateVersion(msg []byte) (int, error) { // DetermineClientMessageType determines the message type from the first the message func DetermineClientMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfCommonHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHello, @@ -105,11 +105,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { // DetermineServerMessageType determines the message type from the first the message func DetermineServerMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfCommonHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHelloResponse, @@ -134,12 +134,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions)) + msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeHello+len(additions)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHello) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader) msg = append(msg, peerID...) msg = append(msg, additions...) @@ -154,11 +154,11 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { if len(msg) < headerSizeHello { return nil, nil, ErrInvalidMessageLength } - if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + if !bytes.Equal(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil + return msg[sizeOfCommonHeader+sizeOfMagicByte : headerSizeHello], msg[headerSizeHello:], nil } // Deprecated: Use MarshalAuthResponse instead. @@ -167,7 +167,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // servers. func MarshalHelloResponse(additionalData []byte) ([]byte, error) { - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) + msg := make([]byte, headerSizeHelloResp, headerSizeHelloResp+len(additionalData)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHelloResponse) @@ -196,12 +196,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) + msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeAuth+len(authPayload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuth) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader) msg = append(msg, peerID...) msg = append(msg, authPayload...) @@ -227,7 +227,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { // servers. func MarshalAuthResponse(address string) ([]byte, error) { ab := []byte(address) - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) + msg := make([]byte, sizeOfCommonHeader, headerSizeAuthResp+len(ab)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuthResponse) @@ -239,7 +239,7 @@ func MarshalAuthResponse(address string) ([]byte, error) { // UnmarshalAuthResponse it is a confirmation message to auth success func UnmarshalAuthResponse(msg []byte) (string, error) { - if len(msg) < headerSizeAuthResp+1 { + if len(msg) < headerSizeAuthResp+1 { // +1 is the minimum expected size of the address return "", ErrInvalidMessageLength } return string(msg), nil @@ -249,7 +249,7 @@ func UnmarshalAuthResponse(msg []byte) (string, error) { // The close message is used to close the connection gracefully between the client and the server. The server and the // client can send this message. After receiving this message, the server or client will close the connection. func MarshalCloseMsg() []byte { - msg := make([]byte, SizeOfProtoHeader) + msg := make([]byte, sizeOfCommonHeader) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeClose) @@ -265,12 +265,12 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload)) + msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeTransport) - copy(msg[SizeOfProtoHeader:], peerID) + copy(msg[sizeOfCommonHeader:], peerID) msg = append(msg, payload...) @@ -283,7 +283,7 @@ func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { return nil, nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], buf[headerSizeTransport:], nil + return buf[sizeOfCommonHeader:headerSizeTransport], buf[headerSizeTransport:], nil } // UnmarshalTransportID extracts the peerID from the transport message. @@ -291,7 +291,7 @@ func UnmarshalTransportID(buf []byte) ([]byte, error) { if len(buf) < headerSizeTransport { return nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], nil + return buf[sizeOfCommonHeader:headerSizeTransport], nil } // UpdateTransportMsg updates the peerID in the transport message. diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index a4e7d9fae..6e67df46a 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -11,13 +11,37 @@ func TestMarshalHelloMsg(t *testing.T) { t.Fatalf("error: %v", err) } - receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:]) + receivedPeerID, addition, err := UnmarshalHelloMsg(bHello) if err != nil { t.Fatalf("error: %v", err) } if string(receivedPeerID) != string(peerID) { t.Errorf("expected %s, got %s", peerID, receivedPeerID) } + + if len(addition) != 0 { + t.Errorf("expected empty addition, got %v", addition) + } +} + +func TestMarshalAuthMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + bHello, err := MarshalAuthMsg(peerID, nil) + if err != nil { + t.Fatalf("error: %v", err) + } + + receivedPeerID, addition, err := UnmarshalAuthMsg(bHello) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(receivedPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, receivedPeerID) + } + + if len(addition) != 0 { + t.Errorf("expected empty addition, got %v", addition) + } } func TestMarshalTransportMsg(t *testing.T) { @@ -28,7 +52,15 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("error: %v", err) } - id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:]) + tid, err := UnmarshalTransportID(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(tid) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, tid) + } + + id, respPayload, err := UnmarshalTransportMsg(msg) if err != nil { t.Fatalf("error: %v", err) } diff --git a/relay/server/peer.go b/relay/server/peer.go index a9583700a..033756956 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -68,21 +68,19 @@ func (p *Peer) Work() { return } - msg := buf[:n] - - _, err = messages.ValidateVersion(msg) + _, err = messages.ValidateVersion(buf[:n]) if err != nil { p.log.Warnf("failed to validate protocol version: %s", err) return } - msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:]) + msgType, err := messages.DetermineClientMessageType(buf[:n]) if err != nil { p.log.Errorf("failed to determine message type: %s", err) return } - p.handleMsgType(ctx, msgType, hc, n, msg) + p.handleMsgType(ctx, msgType, hc, n, buf[:n]) } } @@ -175,7 +173,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send } func (p *Peer) handleTransportMsg(msg []byte) { - peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:]) + peerID, err := messages.UnmarshalTransportID(msg) if err != nil { p.log.Errorf("failed to unmarshal transport message: %s", err) return @@ -188,7 +186,7 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB) + err = messages.UpdateTransportMsg(msg, p.idB) if err != nil { p.log.Errorf("failed to update transport message: %s", err) return diff --git a/relay/server/relay.go b/relay/server/relay.go index 4dc262904..4271f53f6 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -164,7 +164,7 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) { return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) } - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineClientMessageType(buf[:n]) if err != nil { return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) } @@ -175,9 +175,9 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) { ) switch msgType { case messages.MsgTypeHello: - responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + responseMsg, err = r.handleHelloMsg(buf[:n], conn.RemoteAddr()) case messages.MsgTypeAuth: - responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + responseMsg, err = r.handleAuthMsg(buf[:n], conn.RemoteAddr()) default: return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) }