mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
[relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
This commit is contained in:
parent
50ebbe482e
commit
2d1bf3982d
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
|
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultDuration = 12 * time.Hour
|
const defaultDuration = 12 * time.Hour
|
||||||
@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
|
|||||||
turnCfg *TURNConfig
|
turnCfg *TURNConfig
|
||||||
relayCfg *Relay
|
relayCfg *Relay
|
||||||
turnHmacToken *auth.TimedHMAC
|
turnHmacToken *auth.TimedHMAC
|
||||||
relayHmacToken *auth.TimedHMAC
|
relayHmacToken *authv2.Generator
|
||||||
updateManager *PeersUpdateManager
|
updateManager *PeersUpdateManager
|
||||||
turnCancelMap map[string]chan struct{}
|
turnCancelMap map[string]chan struct{}
|
||||||
relayCancelMap map[string]chan struct{}
|
relayCancelMap map[string]chan struct{}
|
||||||
@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
|
|||||||
duration = defaultDuration
|
duration = defaultDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration)
|
hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
|
||||||
|
var err error
|
||||||
|
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
|
||||||
|
log.Errorf("failed to create relay token generator: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mgr
|
return mgr
|
||||||
@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
|
|||||||
}
|
}
|
||||||
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
|
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate TURN token: %s", err)
|
return nil, fmt.Errorf("generate TURN token: %s", err)
|
||||||
}
|
}
|
||||||
return (*Token)(turnToken), nil
|
return (*Token)(turnToken), nil
|
||||||
}
|
}
|
||||||
@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
|
|||||||
if m.relayHmacToken == nil {
|
if m.relayHmacToken == nil {
|
||||||
return nil, fmt.Errorf("relay configuration is not set")
|
return nil, fmt.Errorf("relay configuration is not set")
|
||||||
}
|
}
|
||||||
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
|
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate relay token: %s", err)
|
return nil, fmt.Errorf("generate relay token: %s", err)
|
||||||
}
|
}
|
||||||
return (*Token)(relayToken), nil
|
|
||||||
|
return &Token{
|
||||||
|
Payload: string(relayToken.Payload),
|
||||||
|
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
|
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
|
||||||
@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
|
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
|
||||||
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
|
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
|
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
|
||||||
return
|
return
|
||||||
@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
|
|||||||
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||||
Relay: &proto.RelayConfig{
|
Relay: &proto.RelayConfig{
|
||||||
Urls: m.relayCfg.Addresses,
|
Urls: m.relayCfg.Addresses,
|
||||||
TokenPayload: relayToken.Payload,
|
TokenPayload: string(relayToken.Payload),
|
||||||
TokenSignature: relayToken.Signature,
|
TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||||
},
|
},
|
||||||
// omit Turns to avoid updates there
|
// omit Turns to avoid updates there
|
||||||
},
|
},
|
||||||
|
@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
|||||||
t.Errorf("expected generated relay signature not to be empty, got empty")
|
t.Errorf("expected generated relay signature not to be empty, got empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret))
|
hashedSecret := sha256.Sum256([]byte(secret))
|
||||||
|
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package allow
|
package allow
|
||||||
|
|
||||||
import "hash"
|
|
||||||
|
|
||||||
// Auth is a Validator that allows all connections.
|
// Auth is a Validator that allows all connections.
|
||||||
// Used this for testing purposes only.
|
// Used this for testing purposes only.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) Validate(func() hash.Hash, any) error {
|
func (a *Auth) Validate(any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) ValidateHelloMsgType(any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package hmac
|
package hmac
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenStore is a simple in-memory store for token
|
// TokenStore is a simple in-memory store for token
|
||||||
@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := marshalToken(*token)
|
sig, err := base64.StdEncoding.DecodeString(token.Signature)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to marshal token: %s", err)
|
return fmt.Errorf("decode signature: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
a.token = t
|
|
||||||
|
tok := v2.Token{
|
||||||
|
AuthAlgo: v2.AuthAlgoHMACSHA256,
|
||||||
|
Signature: sig,
|
||||||
|
Payload: []byte(token.Payload),
|
||||||
|
}
|
||||||
|
|
||||||
|
a.token = tok.Marshal()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,17 +18,6 @@ type Token struct {
|
|||||||
Signature string
|
Signature string
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalToken(token Token) ([]byte, error) {
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
encoder := gob.NewEncoder(&buffer)
|
|
||||||
err := encoder.Encode(token)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to marshal token: %s", err)
|
|
||||||
return nil, fmt.Errorf("failed to marshal token: %w", err)
|
|
||||||
}
|
|
||||||
return buffer.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalToken(payload []byte) (Token, error) {
|
func unmarshalToken(payload []byte) (Token, error) {
|
||||||
var creds Token
|
var creds Token
|
||||||
buffer := bytes.NewBuffer(payload)
|
buffer := bytes.NewBuffer(payload)
|
||||||
|
40
relay/auth/hmac/v2/algo.go
Normal file
40
relay/auth/hmac/v2/algo.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthAlgoUnknown AuthAlgo = iota
|
||||||
|
AuthAlgoHMACSHA256
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthAlgo uint8
|
||||||
|
|
||||||
|
func (a AuthAlgo) String() string {
|
||||||
|
switch a {
|
||||||
|
case AuthAlgoHMACSHA256:
|
||||||
|
return "HMAC-SHA256"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AuthAlgo) New() func() hash.Hash {
|
||||||
|
switch a {
|
||||||
|
case AuthAlgoHMACSHA256:
|
||||||
|
return sha256.New
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AuthAlgo) Size() int {
|
||||||
|
switch a {
|
||||||
|
case AuthAlgoHMACSHA256:
|
||||||
|
return sha256.Size
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
45
relay/auth/hmac/v2/generator.go
Normal file
45
relay/auth/hmac/v2/generator.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Generator struct {
|
||||||
|
algo func() hash.Hash
|
||||||
|
algoType AuthAlgo
|
||||||
|
secret []byte
|
||||||
|
timeToLive time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
|
||||||
|
algoFunc := algo.New()
|
||||||
|
if algoFunc == nil {
|
||||||
|
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
|
||||||
|
}
|
||||||
|
return &Generator{
|
||||||
|
algo: algoFunc,
|
||||||
|
algoType: algo,
|
||||||
|
secret: secret,
|
||||||
|
timeToLive: timeToLive,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Generator) GenerateToken() (*Token, error) {
|
||||||
|
expirationTime := time.Now().Add(g.timeToLive).Unix()
|
||||||
|
|
||||||
|
payload := []byte(strconv.FormatInt(expirationTime, 10))
|
||||||
|
|
||||||
|
h := hmac.New(g.algo, g.secret)
|
||||||
|
h.Write(payload)
|
||||||
|
signature := h.Sum(nil)
|
||||||
|
|
||||||
|
return &Token{
|
||||||
|
AuthAlgo: g.algoType,
|
||||||
|
Signature: signature,
|
||||||
|
Payload: payload,
|
||||||
|
}, nil
|
||||||
|
}
|
110
relay/auth/hmac/v2/hmac_test.go
Normal file
110
relay/auth/hmac/v2/hmac_test.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCredentials(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create generator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := g.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(token.Payload) == 0 {
|
||||||
|
t.Fatalf("expected non-empty payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateCredentials(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create generator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := g.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := NewValidator([]byte(secret))
|
||||||
|
if err := v.Validate(token.Marshal()); err != nil {
|
||||||
|
t.Fatalf("expected valid token: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidSignature(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create generator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := g.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||||
|
|
||||||
|
v := NewValidator([]byte(secret))
|
||||||
|
if err := v.Validate(token.Marshal()); err == nil {
|
||||||
|
t.Fatalf("expected valid token: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpired(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := -1 * time.Hour
|
||||||
|
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create generator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := g.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := NewValidator([]byte(secret))
|
||||||
|
if err := v.Validate(token.Marshal()); err == nil {
|
||||||
|
t.Fatalf("expected valid token: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidPayload(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create generator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := g.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||||
|
|
||||||
|
v := NewValidator([]byte(secret))
|
||||||
|
if err := v.Validate(token.Marshal()); err == nil {
|
||||||
|
t.Fatalf("expected invalid token due to invalid payload")
|
||||||
|
}
|
||||||
|
}
|
39
relay/auth/hmac/v2/token.go
Normal file
39
relay/auth/hmac/v2/token.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
AuthAlgo AuthAlgo
|
||||||
|
Signature []byte
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Token) Marshal() []byte {
|
||||||
|
size := 1 + len(t.Signature) + len(t.Payload)
|
||||||
|
|
||||||
|
buf := make([]byte, size)
|
||||||
|
|
||||||
|
buf[0] = byte(t.AuthAlgo)
|
||||||
|
copy(buf[1:], t.Signature)
|
||||||
|
copy(buf[1+len(t.Signature):], t.Payload)
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalToken(data []byte) (*Token, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, errors.New("invalid token data")
|
||||||
|
}
|
||||||
|
|
||||||
|
algo := AuthAlgo(data[0])
|
||||||
|
sigSize := algo.Size()
|
||||||
|
if len(data) < 1+sigSize {
|
||||||
|
return nil, errors.New("invalid token data: insufficient length")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Token{
|
||||||
|
AuthAlgo: algo,
|
||||||
|
Signature: data[1 : 1+sigSize],
|
||||||
|
Payload: data[1+sigSize:],
|
||||||
|
}, nil
|
||||||
|
}
|
59
relay/auth/hmac/v2/validator.go
Normal file
59
relay/auth/hmac/v2/validator.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minLengthUnixTimestamp = 10
|
||||||
|
|
||||||
|
type Validator struct {
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewValidator(secret []byte) *Validator {
|
||||||
|
return &Validator{secret: secret}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Validator) Validate(data any) error {
|
||||||
|
d, ok := data.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid data type")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := UnmarshalToken(d)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unmarshal token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(token.Payload) < minLengthUnixTimestamp {
|
||||||
|
return errors.New("invalid payload: insufficient length")
|
||||||
|
}
|
||||||
|
|
||||||
|
hashFunc := token.AuthAlgo.New()
|
||||||
|
if hashFunc == nil {
|
||||||
|
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hmac.New(hashFunc, v.secret)
|
||||||
|
h.Write(token.Payload)
|
||||||
|
expectedMAC := h.Sum(nil)
|
||||||
|
|
||||||
|
if !hmac.Equal(token.Signature, expectedMAC) {
|
||||||
|
return errors.New("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Unix() > timestamp {
|
||||||
|
return fmt.Errorf("expired token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,8 +1,8 @@
|
|||||||
package hmac
|
package hmac
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
|
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||||
b, ok := credentials.([]byte)
|
b, ok := credentials.([]byte)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid credentials type")
|
return fmt.Errorf("invalid credentials type")
|
||||||
@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
|
|||||||
log.Debugf("failed to unmarshal token: %s", err)
|
log.Debugf("failed to unmarshal token: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return a.TimedHMAC.Validate(algo, c)
|
return a.TimedHMAC.Validate(sha256.New, c)
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,35 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import "hash"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
|
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||||
|
)
|
||||||
|
|
||||||
// Validator is an interface that defines the Validate method.
|
// Validator is an interface that defines the Validate method.
|
||||||
type Validator interface {
|
type Validator interface {
|
||||||
Validate(func() hash.Hash, any) error
|
Validate(any) error
|
||||||
|
// Deprecated: Use Validate instead.
|
||||||
|
ValidateHelloMsgType(any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TimedHMACValidator struct {
|
||||||
|
authenticatorV2 *authv2.Validator
|
||||||
|
authenticator *auth.TimedHMACValidator
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
|
||||||
|
return &TimedHMACValidator{
|
||||||
|
authenticatorV2: authv2.NewValidator(secret),
|
||||||
|
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||||
|
return a.authenticatorV2.Validate(credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
|
||||||
|
return a.authenticator.Validate(credentials)
|
||||||
}
|
}
|
||||||
|
@ -14,8 +14,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
|
||||||
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -240,31 +238,21 @@ func (c *Client) connect() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake() error {
|
||||||
authMsg := &auth2.Msg{
|
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||||
AuthAlgorithm: auth2.AlgoHMACSHA256,
|
|
||||||
AdditionalData: c.authTokenStore.TokenBinary(),
|
|
||||||
}
|
|
||||||
|
|
||||||
authData, err := authMsg.Marshal()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal auth message: %w", err)
|
log.Errorf("failed to marshal auth message: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to marshal hello message: %s", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.relayConn.Write(msg)
|
_, err = c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to send hello message: %s", err)
|
log.Errorf("failed to send auth message: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||||
n, err := c.readWithTimeout(buf)
|
n, err := c.readWithTimeout(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to read hello response: %s", err)
|
log.Errorf("failed to read auth response: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,23 +267,18 @@ func (c *Client) handShake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgType != messages.MsgTypeHelloResponse {
|
if msgType != messages.MsgTypeAuthResponse {
|
||||||
log.Errorf("unexpected message type: %s", msgType)
|
log.Errorf("unexpected message type: %s", msgType)
|
||||||
return fmt.Errorf("unexpected message type")
|
return fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
|
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := address.Unmarshal(additionalData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unmarshal address: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.muInstanceURL.Lock()
|
c.muInstanceURL.Lock()
|
||||||
c.instanceURL = &RelayAddr{addr: addr.URL}
|
c.instanceURL = &RelayAddr{addr: addr}
|
||||||
c.muInstanceURL.Unlock()
|
c.muInstanceURL.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -16,7 +17,7 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
"github.com/netbirdio/netbird/signal/metrics"
|
"github.com/netbirdio/netbird/signal/metrics"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@ -139,7 +140,9 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
srvListenerCfg.TLSConfig = tlsConfig
|
srvListenerCfg.TLSConfig = tlsConfig
|
||||||
|
|
||||||
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
|
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||||
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create relay server: %v", err)
|
log.Debugf("failed to create relay server: %v", err)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||||
package address
|
package address
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Unmarshal(data []byte) (*Address, error) {
|
|
||||||
var addr Address
|
|
||||||
buf := bytes.NewBuffer(data)
|
|
||||||
dec := gob.NewDecoder(buf)
|
|
||||||
if err := dec.Decode(&addr); err != nil {
|
|
||||||
return nil, fmt.Errorf("decode Address: %w", err)
|
|
||||||
}
|
|
||||||
return &addr, nil
|
|
||||||
}
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -30,15 +31,6 @@ type Msg struct {
|
|||||||
AdditionalData []byte
|
AdditionalData []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (msg *Msg) Marshal() ([]byte, error) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc := gob.NewEncoder(&buf)
|
|
||||||
if err := enc.Encode(msg); err != nil {
|
|
||||||
return nil, fmt.Errorf("encode Msg: %w", err)
|
|
||||||
}
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalMsg(data []byte) (*Msg, error) {
|
func UnmarshalMsg(data []byte) (*Msg, error) {
|
||||||
var msg *Msg
|
var msg *Msg
|
||||||
|
|
||||||
|
@ -7,12 +7,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
MaxHandshakeSize = 212
|
||||||
|
MaxHandshakeRespSize = 8192
|
||||||
|
|
||||||
|
CurrentProtocolVersion = 1
|
||||||
|
|
||||||
MsgTypeUnknown MsgType = 0
|
MsgTypeUnknown MsgType = 0
|
||||||
|
// Deprecated: Use MsgTypeAuth instead.
|
||||||
MsgTypeHello MsgType = 1
|
MsgTypeHello MsgType = 1
|
||||||
|
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||||
MsgTypeHelloResponse MsgType = 2
|
MsgTypeHelloResponse MsgType = 2
|
||||||
MsgTypeTransport MsgType = 3
|
MsgTypeTransport MsgType = 3
|
||||||
MsgTypeClose MsgType = 4
|
MsgTypeClose MsgType = 4
|
||||||
MsgTypeHealthCheck MsgType = 5
|
MsgTypeHealthCheck MsgType = 5
|
||||||
|
MsgTypeAuth = 6
|
||||||
|
MsgTypeAuthResponse = 7
|
||||||
|
|
||||||
SizeOfVersionByte = 1
|
SizeOfVersionByte = 1
|
||||||
SizeOfMsgType = 1
|
SizeOfMsgType = 1
|
||||||
@ -22,12 +31,12 @@ const (
|
|||||||
sizeOfMagicByte = 4
|
sizeOfMagicByte = 4
|
||||||
|
|
||||||
headerSizeTransport = IDSize
|
headerSizeTransport = IDSize
|
||||||
|
|
||||||
headerSizeHello = sizeOfMagicByte + IDSize
|
headerSizeHello = sizeOfMagicByte + IDSize
|
||||||
headerSizeHelloResp = 0
|
headerSizeHelloResp = 0
|
||||||
|
|
||||||
MaxHandshakeSize = 8192
|
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||||
|
headerSizeAuthResp = 0
|
||||||
CurrentProtocolVersion = 1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -47,6 +56,10 @@ func (m MsgType) String() string {
|
|||||||
return "hello"
|
return "hello"
|
||||||
case MsgTypeHelloResponse:
|
case MsgTypeHelloResponse:
|
||||||
return "hello response"
|
return "hello response"
|
||||||
|
case MsgTypeAuth:
|
||||||
|
return "auth"
|
||||||
|
case MsgTypeAuthResponse:
|
||||||
|
return "auth response"
|
||||||
case MsgTypeTransport:
|
case MsgTypeTransport:
|
||||||
return "transport"
|
return "transport"
|
||||||
case MsgTypeClose:
|
case MsgTypeClose:
|
||||||
@ -58,10 +71,6 @@ func (m MsgType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type HelloResponse struct {
|
|
||||||
InstanceAddress string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateVersion checks if the given version is supported by the protocol
|
// ValidateVersion checks if the given version is supported by the protocol
|
||||||
func ValidateVersion(msg []byte) (int, error) {
|
func ValidateVersion(msg []byte) (int, error) {
|
||||||
if len(msg) < SizeOfVersionByte {
|
if len(msg) < SizeOfVersionByte {
|
||||||
@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
|||||||
switch msgType {
|
switch msgType {
|
||||||
case
|
case
|
||||||
MsgTypeHello,
|
MsgTypeHello,
|
||||||
|
MsgTypeAuth,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck:
|
||||||
@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
switch msgType {
|
switch msgType {
|
||||||
case
|
case
|
||||||
MsgTypeHelloResponse,
|
MsgTypeHelloResponse,
|
||||||
|
MsgTypeAuthResponse,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck:
|
||||||
@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use MarshalAuthMsg instead.
|
||||||
// MarshalHelloMsg initial hello message
|
// MarshalHelloMsg initial hello message
|
||||||
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||||
// authenticate the client with the server.
|
// authenticate the client with the server.
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||||
@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
|||||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use MarshalAuthResponse instead.
|
||||||
// MarshalHelloResponse creates a response message to the hello message.
|
// MarshalHelloResponse creates a response message to the hello message.
|
||||||
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
||||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||||
@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||||
if len(msg) < headerSizeHelloResp {
|
if len(msg) < headerSizeHelloResp {
|
||||||
@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalAuthMsg initial authentication message
|
||||||
|
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||||
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
|
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||||
|
// close the network connection without any response.
|
||||||
|
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||||
|
if len(peerID) != IDSize {
|
||||||
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
||||||
|
|
||||||
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
|
msg[1] = byte(MsgTypeAuth)
|
||||||
|
|
||||||
|
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||||
|
|
||||||
|
msg = append(msg, peerID...)
|
||||||
|
msg = append(msg, authPayload...)
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||||
|
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||||
|
if len(msg) < headerSizeAuth {
|
||||||
|
return nil, nil, ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||||
|
return nil, nil, errors.New("invalid magic header")
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalAuthResponse creates a response message to the auth.
|
||||||
|
// In case of success connection the server response with a AuthResponse message. This message contains the server's
|
||||||
|
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||||
|
// servers.
|
||||||
|
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||||
|
ab := []byte(address)
|
||||||
|
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
||||||
|
|
||||||
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
|
msg[1] = byte(MsgTypeAuthResponse)
|
||||||
|
|
||||||
|
msg = append(msg, ab...)
|
||||||
|
|
||||||
|
if len(msg) > MaxHandshakeRespSize {
|
||||||
|
return nil, fmt.Errorf("invalid message length: %d", len(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||||
|
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||||
|
if len(msg) < headerSizeAuthResp+1 {
|
||||||
|
return "", ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
return string(msg), nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalCloseMsg creates a close message.
|
// MarshalCloseMsg creates a close message.
|
||||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||||
|
@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalAuthMsg(t *testing.T) {
|
||||||
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
|
bHello, err := MarshalAuthMsg(peerID, []byte{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
if string(receivedPeerID) != string(peerID) {
|
||||||
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshalTransportMsg(t *testing.T) {
|
func TestMarshalTransportMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
payload := []byte("payload")
|
payload := []byte("payload")
|
||||||
|
@ -2,7 +2,6 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -14,7 +13,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
|
//nolint:staticcheck
|
||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
)
|
)
|
||||||
@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgType != messages.MsgTypeHello {
|
var (
|
||||||
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
|
responseMsg []byte
|
||||||
|
peerID []byte
|
||||||
|
)
|
||||||
|
switch msgType {
|
||||||
|
//nolint:staticcheck
|
||||||
|
case messages.MsgTypeHello:
|
||||||
|
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||||
|
case messages.MsgTypeAuth:
|
||||||
|
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
_, err = conn.Write(responseMsg)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
|
|
||||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &address.Address{URL: r.instanceURL}
|
|
||||||
addrData, err := addr.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := messages.MarshalHelloResponse(addrData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.Write(msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return peerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
||||||
|
//nolint:staticcheck
|
||||||
|
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
||||||
|
|
||||||
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := &address.Address{URL: r.instanceURL}
|
||||||
|
addrData, err := addr.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
||||||
|
}
|
||||||
|
return rawPeerID, responseMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
||||||
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
|
||||||
|
if err := r.validator.Validate(authPayload); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, responseMsg, nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user