mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-01 00:11:36 +01:00
- Add sha prefix for peer id in protocol
- Add magic cookie in hello msg - Add tests
This commit is contained in:
parent
0a67f5be1a
commit
085d072b17
@ -371,7 +371,11 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
*/
|
||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
n, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
|
@ -194,6 +194,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
log.Debugf("check if foreign server: %s != %s", rAddr.String(), address)
|
||||
return rAddr.String() != address, nil
|
||||
}
|
||||
|
||||
|
@ -3,18 +3,25 @@ package messages
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
IDSize = sha256.Size
|
||||
prefixLength = 4
|
||||
IDSize = sha256.Size + 4 // 4 is equal with len(prefix)
|
||||
)
|
||||
|
||||
var (
|
||||
prefix = []byte("sha-") // 4 bytes
|
||||
)
|
||||
|
||||
func HashID(peerID string) ([]byte, string) {
|
||||
idHash := sha256.Sum256([]byte(peerID))
|
||||
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
|
||||
return idHash[:], idHashString
|
||||
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
|
||||
prefixedHash := append(prefix, idHash[:]...)
|
||||
return prefixedHash, idHashString
|
||||
}
|
||||
|
||||
func HashIDToString(idHash []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(idHash)
|
||||
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
|
||||
}
|
||||
|
26
relay/messages/id_test.go
Normal file
26
relay/messages/id_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestHashID(t *testing.T) {
|
||||
hashedID, hashedStringId := HashID("abc")
|
||||
enc := HashIDToString(hashedID)
|
||||
if enc != hashedStringId {
|
||||
t.Errorf("expected %s, got %s", hashedStringId, enc)
|
||||
}
|
||||
|
||||
var magicHeader uint32 = 0x2112A442 // size 4 byte
|
||||
|
||||
msg := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msg, magicHeader)
|
||||
|
||||
magicHeader2 := []byte{0x21, 0x12, 0xA4, 0x42}
|
||||
|
||||
log.Infof("msg: %v, %v", msg, magicHeader2)
|
||||
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -11,10 +12,15 @@ const (
|
||||
MsgTypeHelloResponse MsgType = 1
|
||||
MsgTypeTransport MsgType = 2
|
||||
MsgClose MsgType = 3
|
||||
|
||||
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
|
||||
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
||||
|
||||
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
||||
)
|
||||
|
||||
type MsgType byte
|
||||
@ -35,7 +41,6 @@ func (m MsgType) String() string {
|
||||
}
|
||||
|
||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case MsgTypeHello:
|
||||
@ -50,7 +55,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
}
|
||||
|
||||
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case MsgTypeHelloResponse:
|
||||
@ -67,19 +71,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
// MarshalHelloMsg initial hello message
|
||||
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length")
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
msg := make([]byte, 1, 1+len(peerID))
|
||||
msg := make([]byte, 5, headerSizeHello)
|
||||
msg[0] = byte(MsgTypeHello)
|
||||
copy(msg[1:5], magicHeader)
|
||||
msg = append(msg, peerID...)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
||||
if len(msg) < 2 {
|
||||
if len(msg) < headerSizeHello {
|
||||
return nil, fmt.Errorf("invalid 'hello' messge")
|
||||
}
|
||||
return msg[1:], nil
|
||||
bytes.Equal(msg[1:5], magicHeader)
|
||||
return msg[5:], nil
|
||||
}
|
||||
|
||||
func MarshalHelloResponse() []byte {
|
||||
@ -98,34 +104,32 @@ func MarshalCloseMsg() []byte {
|
||||
|
||||
// Transport message
|
||||
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
|
||||
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
|
||||
msg[0] = byte(MsgTypeTransport)
|
||||
copy(msg[1:], peerID)
|
||||
msg = append(msg, payload...)
|
||||
return msg
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
headerSize := 1 + IDSize
|
||||
if len(buf) < headerSize {
|
||||
if len(buf) < headerSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[1:headerSize], buf[headerSize:], nil
|
||||
return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
}
|
||||
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
headerSize := 1 + IDSize
|
||||
if len(buf) < headerSize {
|
||||
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
|
||||
if len(buf) < headerSizeTransport {
|
||||
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[1:headerSize], nil
|
||||
return buf[1:headerSizeTransport], nil
|
||||
}
|
||||
|
||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||
|
43
relay/messages/message_test.go
Normal file
43
relay/messages/message_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalHelloMsg(peerID)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, err := UnmarshalHelloMsg(bHello)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(receivedPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
msg, err := MarshalTransportMsg(peerID, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if string(id) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, id)
|
||||
}
|
||||
|
||||
if string(respPayload) != string(payload) {
|
||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user