- Add sha prefix for peer id in protocol

- Add magic cookie in hello msg
- Add tests
This commit is contained in:
Zoltán Papp 2024-06-25 17:36:04 +02:00
parent 0a67f5be1a
commit 085d072b17
6 changed files with 107 additions and 22 deletions

View File

@ -371,7 +371,11 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
} }
*/ */
msg := messages.MarshalTransportMsg(dstID, payload) msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
n, err := c.relayConn.Write(msg) n, err := c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to write transport message: %s", err) log.Errorf("failed to write transport message: %s", err)

View File

@ -194,6 +194,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("relay client not connected") return false, fmt.Errorf("relay client not connected")
} }
log.Debugf("check if foreign server: %s != %s", rAddr.String(), address)
return rAddr.String() != address, nil return rAddr.String() != address, nil
} }

View File

@ -3,18 +3,25 @@ package messages
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt"
) )
const ( const (
IDSize = sha256.Size prefixLength = 4
IDSize = sha256.Size + 4 // 4 is equal with len(prefix)
)
var (
prefix = []byte("sha-") // 4 bytes
) )
func HashID(peerID string) ([]byte, string) { func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID)) idHash := sha256.Sum256([]byte(peerID))
idHashString := base64.StdEncoding.EncodeToString(idHash[:]) idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
return idHash[:], idHashString prefixedHash := append(prefix, idHash[:]...)
return prefixedHash, idHashString
} }
func HashIDToString(idHash []byte) string { func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash) return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
} }

26
relay/messages/id_test.go Normal file
View File

@ -0,0 +1,26 @@
package messages
import (
"encoding/binary"
"testing"
log "github.com/sirupsen/logrus"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("abc")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
var magicHeader uint32 = 0x2112A442 // size 4 byte
msg := make([]byte, 4)
binary.BigEndian.PutUint32(msg, magicHeader)
magicHeader2 := []byte{0x21, 0x12, 0xA4, 0x42}
log.Infof("msg: %v, %v", msg, magicHeader2)
}

View File

@ -1,6 +1,7 @@
package messages package messages
import ( import (
"bytes"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -11,10 +12,15 @@ const (
MsgTypeHelloResponse MsgType = 1 MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2 MsgTypeTransport MsgType = 2
MsgClose MsgType = 3 MsgClose MsgType = 3
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
) )
var ( var (
ErrInvalidMessageLength = fmt.Errorf("invalid message length") ErrInvalidMessageLength = fmt.Errorf("invalid message length")
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
) )
type MsgType byte type MsgType byte
@ -35,7 +41,6 @@ func (m MsgType) String() string {
} }
func DetermineClientMsgType(msg []byte) (MsgType, error) { func DetermineClientMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0]) msgType := MsgType(msg[0])
switch msgType { switch msgType {
case MsgTypeHello: case MsgTypeHello:
@ -50,7 +55,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
} }
func DetermineServerMsgType(msg []byte) (MsgType, error) { func DetermineServerMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0]) msgType := MsgType(msg[0])
switch msgType { switch msgType {
case MsgTypeHelloResponse: case MsgTypeHelloResponse:
@ -67,19 +71,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) { func MarshalHelloMsg(peerID []byte) ([]byte, error) {
if len(peerID) != IDSize { if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length") return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, 1, 1+len(peerID)) msg := make([]byte, 5, headerSizeHello)
msg[0] = byte(MsgTypeHello) msg[0] = byte(MsgTypeHello)
copy(msg[1:5], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID...)
return msg, nil return msg, nil
} }
func UnmarshalHelloMsg(msg []byte) ([]byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
if len(msg) < 2 { if len(msg) < headerSizeHello {
return nil, fmt.Errorf("invalid 'hello' messge") return nil, fmt.Errorf("invalid 'hello' messge")
} }
return msg[1:], nil bytes.Equal(msg[1:5], magicHeader)
return msg[5:], nil
} }
func MarshalHelloResponse() []byte { func MarshalHelloResponse() []byte {
@ -98,34 +104,32 @@ func MarshalCloseMsg() []byte {
// Transport message // Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) []byte { func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { if len(peerID) != IDSize {
return nil return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload)) msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
msg[0] = byte(MsgTypeTransport) msg[0] = byte(MsgTypeTransport)
copy(msg[1:], peerID) copy(msg[1:], peerID)
msg = append(msg, payload...) msg = append(msg, payload...)
return msg return msg, nil
} }
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
headerSize := 1 + IDSize if len(buf) < headerSizeTransport {
if len(buf) < headerSize {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[1:headerSize], buf[headerSize:], nil return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
} }
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) ([]byte, error) {
headerSize := 1 + IDSize if len(buf) < headerSizeTransport {
if len(buf) < headerSize { log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[1:headerSize], nil return buf[1:headerSizeTransport], nil
} }
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID []byte) error {

View File

@ -0,0 +1,43 @@
package messages
import (
"testing"
)
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, err := UnmarshalHelloMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload)
if err != nil {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(id) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, id)
}
if string(respPayload) != string(payload) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}