mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 14:03:35 +01:00
Move header offset calculation to private values
This commit is contained in:
parent
56badd7535
commit
b87173f47d
@ -261,7 +261,7 @@ func (c *Client) handShake() error {
|
||||
return fmt.Errorf("validate version: %w", err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
@ -272,7 +272,7 @@ func (c *Client) handShake() error {
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -312,14 +312,14 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
||||
if !c.handleMsg(msgType, buf[:n], bufPtr, hc, internallyStoppedFlag) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -23,20 +23,20 @@ const (
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
SizeOfVersionByte = 1
|
||||
SizeOfMsgType = 1
|
||||
sizeOfVersionByte = 1
|
||||
sizeOfMsgType = 1
|
||||
|
||||
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
|
||||
sizeOfCommonHeader = sizeOfVersionByte + sizeOfMsgType
|
||||
|
||||
sizeOfMagicByte = 4
|
||||
|
||||
headerSizeTransport = IDSize
|
||||
headerSizeTransport = sizeOfCommonHeader + IDSize
|
||||
|
||||
headerSizeHello = sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = 0
|
||||
headerSizeHello = sizeOfCommonHeader + sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = sizeOfCommonHeader + sizeOfCommonHeader
|
||||
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
headerSizeAuthResp = 0
|
||||
headerSizeAuth = sizeOfCommonHeader + sizeOfMagicByte + IDSize
|
||||
headerSizeAuthResp = sizeOfCommonHeader
|
||||
)
|
||||
|
||||
var (
|
||||
@ -73,7 +73,7 @@ func (m MsgType) String() string {
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < SizeOfVersionByte {
|
||||
if len(msg) < sizeOfCommonHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
@ -85,11 +85,11 @@ func ValidateVersion(msg []byte) (int, error) {
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfCommonHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
@ -105,11 +105,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfCommonHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
@ -134,12 +134,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, additions...)
|
||||
@ -154,11 +154,11 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
if !bytes.Equal(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||
return msg[sizeOfCommonHeader+sizeOfMagicByte : headerSizeHello], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
@ -167,7 +167,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
msg := make([]byte, headerSizeHelloResp, headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
@ -196,12 +196,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
||||
msg := make([]byte, sizeOfCommonHeader+sizeOfMagicByte, sizeOfCommonHeader+headerSizeAuth+len(authPayload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfCommonHeader:sizeOfCommonHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, authPayload...)
|
||||
@ -227,7 +227,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
||||
msg := make([]byte, sizeOfCommonHeader, headerSizeAuthResp+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
@ -239,7 +239,7 @@ func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < headerSizeAuthResp+1 {
|
||||
if len(msg) < headerSizeAuthResp+1 { // +1 is the minimum expected size of the address
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg), nil
|
||||
@ -249,7 +249,7 @@ func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, SizeOfProtoHeader)
|
||||
msg := make([]byte, sizeOfCommonHeader)
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeClose)
|
||||
@ -265,12 +265,12 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
|
||||
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:], peerID)
|
||||
copy(msg[sizeOfCommonHeader:], peerID)
|
||||
|
||||
msg = append(msg, payload...)
|
||||
|
||||
@ -283,7 +283,7 @@ func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
return buf[sizeOfCommonHeader:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
@ -291,7 +291,7 @@ func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[:headerSizeTransport], nil
|
||||
return buf[sizeOfCommonHeader:headerSizeTransport], nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
|
@ -11,13 +11,37 @@ func TestMarshalHelloMsg(t *testing.T) {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
|
||||
receivedPeerID, addition, err := UnmarshalHelloMsg(bHello)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(receivedPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
|
||||
if len(addition) != 0 {
|
||||
t.Errorf("expected empty addition, got %v", addition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalAuthMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, addition, err := UnmarshalAuthMsg(bHello)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(receivedPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
|
||||
if len(addition) != 0 {
|
||||
t.Errorf("expected empty addition, got %v", addition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
@ -28,7 +52,15 @@ func TestMarshalTransportMsg(t *testing.T) {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
|
||||
tid, err := UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(tid) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, tid)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
@ -68,21 +68,19 @@ func (p *Peer) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
msg := buf[:n]
|
||||
|
||||
_, err = messages.ValidateVersion(msg)
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
p.log.Warnf("failed to validate protocol version: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
|
||||
msgType, err := messages.DetermineClientMessageType(buf[:n])
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.handleMsgType(ctx, msgType, hc, n, msg)
|
||||
p.handleMsgType(ctx, msgType, hc, n, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,7 +173,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
||||
}
|
||||
|
||||
func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
|
||||
peerID, err := messages.UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
return
|
||||
@ -188,7 +186,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
|
||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
return
|
||||
|
@ -164,7 +164,7 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineClientMessageType(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
@ -175,9 +175,9 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
)
|
||||
switch msgType {
|
||||
case messages.MsgTypeHello:
|
||||
responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
responseMsg, err = r.handleHelloMsg(buf[:n], conn.RemoteAddr())
|
||||
case messages.MsgTypeAuth:
|
||||
responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
responseMsg, err = r.handleAuthMsg(buf[:n], conn.RemoteAddr())
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user