mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
- Add sha prefix for peer id in protocol
- Add magic cookie in hello msg - Add tests
This commit is contained in:
parent
0a67f5be1a
commit
085d072b17
@ -371,7 +371,11 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
|||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to marshal transport message: %s", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
n, err := c.relayConn.Write(msg)
|
n, err := c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to write transport message: %s", err)
|
log.Errorf("failed to write transport message: %s", err)
|
||||||
|
@ -194,6 +194,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("relay client not connected")
|
return false, fmt.Errorf("relay client not connected")
|
||||||
}
|
}
|
||||||
|
log.Debugf("check if foreign server: %s != %s", rAddr.String(), address)
|
||||||
return rAddr.String() != address, nil
|
return rAddr.String() != address, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,18 +3,25 @@ package messages
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IDSize = sha256.Size
|
prefixLength = 4
|
||||||
|
IDSize = sha256.Size + 4 // 4 is equal with len(prefix)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
prefix = []byte("sha-") // 4 bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
func HashID(peerID string) ([]byte, string) {
|
func HashID(peerID string) ([]byte, string) {
|
||||||
idHash := sha256.Sum256([]byte(peerID))
|
idHash := sha256.Sum256([]byte(peerID))
|
||||||
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
|
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
|
||||||
return idHash[:], idHashString
|
prefixedHash := append(prefix, idHash[:]...)
|
||||||
|
return prefixedHash, idHashString
|
||||||
}
|
}
|
||||||
|
|
||||||
func HashIDToString(idHash []byte) string {
|
func HashIDToString(idHash []byte) string {
|
||||||
return base64.StdEncoding.EncodeToString(idHash)
|
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
|
||||||
}
|
}
|
||||||
|
26
relay/messages/id_test.go
Normal file
26
relay/messages/id_test.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHashID(t *testing.T) {
|
||||||
|
hashedID, hashedStringId := HashID("abc")
|
||||||
|
enc := HashIDToString(hashedID)
|
||||||
|
if enc != hashedStringId {
|
||||||
|
t.Errorf("expected %s, got %s", hashedStringId, enc)
|
||||||
|
}
|
||||||
|
|
||||||
|
var magicHeader uint32 = 0x2112A442 // size 4 byte
|
||||||
|
|
||||||
|
msg := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(msg, magicHeader)
|
||||||
|
|
||||||
|
magicHeader2 := []byte{0x21, 0x12, 0xA4, 0x42}
|
||||||
|
|
||||||
|
log.Infof("msg: %v, %v", msg, magicHeader2)
|
||||||
|
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -11,10 +12,15 @@ const (
|
|||||||
MsgTypeHelloResponse MsgType = 1
|
MsgTypeHelloResponse MsgType = 1
|
||||||
MsgTypeTransport MsgType = 2
|
MsgTypeTransport MsgType = 2
|
||||||
MsgClose MsgType = 3
|
MsgClose MsgType = 3
|
||||||
|
|
||||||
|
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
|
||||||
|
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
||||||
|
|
||||||
|
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
||||||
)
|
)
|
||||||
|
|
||||||
type MsgType byte
|
type MsgType byte
|
||||||
@ -35,7 +41,6 @@ func (m MsgType) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||||
// todo: validate magic byte
|
|
||||||
msgType := MsgType(msg[0])
|
msgType := MsgType(msg[0])
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case MsgTypeHello:
|
case MsgTypeHello:
|
||||||
@ -50,7 +55,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||||
// todo: validate magic byte
|
|
||||||
msgType := MsgType(msg[0])
|
msgType := MsgType(msg[0])
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case MsgTypeHelloResponse:
|
case MsgTypeHelloResponse:
|
||||||
@ -67,19 +71,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
|||||||
// MarshalHelloMsg initial hello message
|
// MarshalHelloMsg initial hello message
|
||||||
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if len(peerID) != IDSize {
|
||||||
return nil, fmt.Errorf("invalid peerID length")
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
msg := make([]byte, 1, 1+len(peerID))
|
msg := make([]byte, 5, headerSizeHello)
|
||||||
msg[0] = byte(MsgTypeHello)
|
msg[0] = byte(MsgTypeHello)
|
||||||
|
copy(msg[1:5], magicHeader)
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID...)
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
||||||
if len(msg) < 2 {
|
if len(msg) < headerSizeHello {
|
||||||
return nil, fmt.Errorf("invalid 'hello' messge")
|
return nil, fmt.Errorf("invalid 'hello' messge")
|
||||||
}
|
}
|
||||||
return msg[1:], nil
|
bytes.Equal(msg[1:5], magicHeader)
|
||||||
|
return msg[5:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalHelloResponse() []byte {
|
func MarshalHelloResponse() []byte {
|
||||||
@ -98,34 +104,32 @@ func MarshalCloseMsg() []byte {
|
|||||||
|
|
||||||
// Transport message
|
// Transport message
|
||||||
|
|
||||||
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if len(peerID) != IDSize {
|
||||||
return nil
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
|
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
|
||||||
msg[0] = byte(MsgTypeTransport)
|
msg[0] = byte(MsgTypeTransport)
|
||||||
copy(msg[1:], peerID)
|
copy(msg[1:], peerID)
|
||||||
msg = append(msg, payload...)
|
msg = append(msg, payload...)
|
||||||
return msg
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||||
headerSize := 1 + IDSize
|
if len(buf) < headerSizeTransport {
|
||||||
if len(buf) < headerSize {
|
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf[1:headerSize], buf[headerSize:], nil
|
return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||||
headerSize := 1 + IDSize
|
if len(buf) < headerSizeTransport {
|
||||||
if len(buf) < headerSize {
|
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
|
||||||
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
|
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return buf[1:headerSize], nil
|
return buf[1:headerSizeTransport], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||||
|
43
relay/messages/message_test.go
Normal file
43
relay/messages/message_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMarshalHelloMsg(t *testing.T) {
|
||||||
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
|
bHello, err := MarshalHelloMsg(peerID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedPeerID, err := UnmarshalHelloMsg(bHello)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
if string(receivedPeerID) != string(peerID) {
|
||||||
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalTransportMsg(t *testing.T) {
|
||||||
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
|
payload := []byte("payload")
|
||||||
|
msg, err := MarshalTransportMsg(peerID, payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(id) != string(peerID) {
|
||||||
|
t.Errorf("expected %s, got %s", peerID, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(respPayload) != string(payload) {
|
||||||
|
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user