mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
201 lines
4.6 KiB
Go
201 lines
4.6 KiB
Go
package messages
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/gob"
|
|
"fmt"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
|
|
MsgTypeHello MsgType = 0
|
|
MsgTypeHelloResponse MsgType = 1
|
|
MsgTypeTransport MsgType = 2
|
|
MsgTypeClose MsgType = 3
|
|
MsgTypeHealthCheck MsgType = 4
|
|
|
|
sizeOfMsgType = 1
|
|
sizeOfMagicBye = 4
|
|
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
|
|
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
|
|
|
MaxHandshakeSize = 8192
|
|
)
|
|
|
|
var (
|
|
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
|
|
|
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
|
|
|
healthCheckMsg = []byte{byte(MsgTypeHealthCheck)}
|
|
)
|
|
|
|
type MsgType byte
|
|
|
|
func (m MsgType) String() string {
|
|
switch m {
|
|
case MsgTypeHello:
|
|
return "hello"
|
|
case MsgTypeHelloResponse:
|
|
return "hello response"
|
|
case MsgTypeTransport:
|
|
return "transport"
|
|
case MsgTypeClose:
|
|
return "close"
|
|
case MsgTypeHealthCheck:
|
|
return "health check"
|
|
default:
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
type HelloResponse struct {
|
|
InstanceAddress string
|
|
}
|
|
|
|
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
|
msgType := MsgType(msg[0])
|
|
switch msgType {
|
|
case MsgTypeHello:
|
|
return msgType, nil
|
|
case MsgTypeTransport:
|
|
return msgType, nil
|
|
case MsgTypeClose:
|
|
return msgType, nil
|
|
case MsgTypeHealthCheck:
|
|
return msgType, nil
|
|
default:
|
|
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
|
}
|
|
}
|
|
|
|
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
|
msgType := MsgType(msg[0])
|
|
switch msgType {
|
|
case MsgTypeHelloResponse:
|
|
return msgType, nil
|
|
case MsgTypeTransport:
|
|
return msgType, nil
|
|
case MsgTypeClose:
|
|
return msgType, nil
|
|
case MsgTypeHealthCheck:
|
|
return msgType, nil
|
|
default:
|
|
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
|
}
|
|
}
|
|
|
|
// MarshalHelloMsg initial hello message
|
|
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|
if len(peerID) != IDSize {
|
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
|
}
|
|
|
|
// 5 = 1 byte for msg type, 4 byte for magic header
|
|
msg := make([]byte, 5, headerSizeHello+len(additions))
|
|
msg[0] = byte(MsgTypeHello)
|
|
copy(msg[1:5], magicHeader)
|
|
msg = append(msg, peerID...)
|
|
msg = append(msg, additions...)
|
|
return msg, nil
|
|
}
|
|
|
|
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
|
if len(msg) < headerSizeHello {
|
|
return nil, nil, fmt.Errorf("invalid 'hello' messge")
|
|
}
|
|
if !bytes.Equal(msg[1:5], magicHeader) {
|
|
return nil, nil, fmt.Errorf("invalid magic header")
|
|
}
|
|
return msg[5 : 5+IDSize], msg[headerSizeHello:], nil
|
|
}
|
|
|
|
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
|
|
payload := HelloResponse{
|
|
InstanceAddress: DomainAddress,
|
|
}
|
|
|
|
buf := new(bytes.Buffer)
|
|
enc := gob.NewEncoder(buf)
|
|
|
|
err := enc.Encode(payload)
|
|
if err != nil {
|
|
log.Errorf("failed to gob encode hello response: %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
msg := make([]byte, 1, 1+buf.Len())
|
|
msg[0] = byte(MsgTypeHelloResponse)
|
|
msg = append(msg, buf.Bytes()...)
|
|
return msg, nil
|
|
}
|
|
|
|
func UnmarshalHelloResponse(msg []byte) (string, error) {
|
|
if len(msg) < 2 {
|
|
return "", fmt.Errorf("invalid 'hello response' message")
|
|
}
|
|
payload := HelloResponse{}
|
|
buf := bytes.NewBuffer(msg[1:])
|
|
dec := gob.NewDecoder(buf)
|
|
|
|
err := dec.Decode(&payload)
|
|
if err != nil {
|
|
log.Errorf("failed to gob decode hello response: %s", err)
|
|
return "", err
|
|
}
|
|
return payload.InstanceAddress, nil
|
|
}
|
|
|
|
// Close message
|
|
|
|
func MarshalCloseMsg() []byte {
|
|
msg := make([]byte, 1)
|
|
msg[0] = byte(MsgTypeClose)
|
|
return msg
|
|
}
|
|
|
|
// Transport message
|
|
|
|
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
|
if len(peerID) != IDSize {
|
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
|
}
|
|
|
|
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
|
|
msg[0] = byte(MsgTypeTransport)
|
|
copy(msg[1:], peerID)
|
|
msg = append(msg, payload...)
|
|
return msg, nil
|
|
}
|
|
|
|
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
|
if len(buf) < headerSizeTransport {
|
|
return nil, nil, ErrInvalidMessageLength
|
|
}
|
|
|
|
return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
|
|
}
|
|
|
|
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
|
if len(buf) < headerSizeTransport {
|
|
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
|
|
return nil, ErrInvalidMessageLength
|
|
}
|
|
return buf[1:headerSizeTransport], nil
|
|
}
|
|
|
|
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
|
if len(msg) < 1+len(peerID) {
|
|
return ErrInvalidMessageLength
|
|
}
|
|
copy(msg[1:], peerID)
|
|
return nil
|
|
}
|
|
|
|
// health check message
|
|
|
|
func MarshalHealthcheck() []byte {
|
|
return healthCheckMsg
|
|
}
|