- Add sha prefix for peer id in protocol

- Add magic cookie in hello msg
- Add tests
This commit is contained in:
Zoltán Papp 2024-06-25 17:36:04 +02:00
parent 0a67f5be1a
commit 085d072b17
6 changed files with 107 additions and 22 deletions

View File

@ -371,7 +371,11 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
return 0, io.EOF
}
*/
msg := messages.MarshalTransportMsg(dstID, payload)
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
n, err := c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)

View File

@ -194,6 +194,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
log.Debugf("check if foreign server: %s != %s", rAddr.String(), address)
return rAddr.String() != address, nil
}

View File

@ -3,18 +3,25 @@ package messages
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
const (
IDSize = sha256.Size
prefixLength = 4
IDSize = sha256.Size + 4 // 4 is equal with len(prefix)
)
var (
prefix = []byte("sha-") // 4 bytes
)
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
return idHash[:], idHashString
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
prefixedHash := append(prefix, idHash[:]...)
return prefixedHash, idHashString
}
func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash)
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
}

26
relay/messages/id_test.go Normal file
View File

@ -0,0 +1,26 @@
package messages
import (
"encoding/binary"
"testing"
log "github.com/sirupsen/logrus"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("abc")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
var magicHeader uint32 = 0x2112A442 // size 4 byte
msg := make([]byte, 4)
binary.BigEndian.PutUint32(msg, magicHeader)
magicHeader2 := []byte{0x21, 0x12, 0xA4, 0x42}
log.Infof("msg: %v, %v", msg, magicHeader2)
}

View File

@ -1,6 +1,7 @@
package messages
import (
"bytes"
"fmt"
log "github.com/sirupsen/logrus"
@ -11,10 +12,15 @@ const (
MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2
MsgClose MsgType = 3
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
)
var (
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
)
type MsgType byte
@ -35,7 +41,6 @@ func (m MsgType) String() string {
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHello:
@ -50,7 +55,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
}
func DetermineServerMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHelloResponse:
@ -67,19 +71,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
// MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length")
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, 1, 1+len(peerID))
msg := make([]byte, 5, headerSizeHello)
msg[0] = byte(MsgTypeHello)
copy(msg[1:5], magicHeader)
msg = append(msg, peerID...)
return msg, nil
}
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
if len(msg) < 2 {
if len(msg) < headerSizeHello {
return nil, fmt.Errorf("invalid 'hello' messge")
}
return msg[1:], nil
bytes.Equal(msg[1:5], magicHeader)
return msg[5:], nil
}
func MarshalHelloResponse() []byte {
@ -98,34 +104,32 @@ func MarshalCloseMsg() []byte {
// Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
msg[0] = byte(MsgTypeTransport)
copy(msg[1:], peerID)
msg = append(msg, payload...)
return msg
return msg, nil
}
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
if len(buf) < headerSizeTransport {
return nil, nil, ErrInvalidMessageLength
}
return buf[1:headerSize], buf[headerSize:], nil
return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
}
func UnmarshalTransportID(buf []byte) ([]byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
if len(buf) < headerSizeTransport {
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
return nil, ErrInvalidMessageLength
}
return buf[1:headerSize], nil
return buf[1:headerSizeTransport], nil
}
func UpdateTransportMsg(msg []byte, peerID []byte) error {

View File

@ -0,0 +1,43 @@
package messages
import (
"testing"
)
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, err := UnmarshalHelloMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload)
if err != nil {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(id) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, id)
}
if string(respPayload) != string(payload) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}