[relay] Improve relay messages (#2574)

Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
This commit is contained in:
Viktor Liu
2024-09-11 16:20:30 +02:00
committed by GitHub
parent 50ebbe482e
commit 2d1bf3982d
19 changed files with 552 additions and 116 deletions

View File

@ -1,12 +1,14 @@
package allow
import "hash"
// Auth is a Validator that allows all connections.
// Used this for testing purposes only.
type Auth struct {
}
func (a *Auth) Validate(func() hash.Hash, any) error {
func (a *Auth) Validate(any) error {
return nil
}
func (a *Auth) ValidateHelloMsgType(any) error {
return nil
}

View File

@ -1,9 +1,11 @@
package hmac
import (
"encoding/base64"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// TokenStore is a simple in-memory store for token
@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
return nil
}
t, err := marshalToken(*token)
sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return err
return fmt.Errorf("decode signature: %w", err)
}
a.token = t
tok := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: sig,
Payload: []byte(token.Payload),
}
a.token = tok.Marshal()
return nil
}

View File

@ -18,17 +18,6 @@ type Token struct {
Signature string
}
func marshalToken(token Token) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)

View File

@ -0,0 +1,40 @@
package v2
import (
"crypto/sha256"
"hash"
)
const (
AuthAlgoUnknown AuthAlgo = iota
AuthAlgoHMACSHA256
)
type AuthAlgo uint8
func (a AuthAlgo) String() string {
switch a {
case AuthAlgoHMACSHA256:
return "HMAC-SHA256"
default:
return "Unknown"
}
}
func (a AuthAlgo) New() func() hash.Hash {
switch a {
case AuthAlgoHMACSHA256:
return sha256.New
default:
return nil
}
}
func (a AuthAlgo) Size() int {
switch a {
case AuthAlgoHMACSHA256:
return sha256.Size
default:
return 0
}
}

View File

@ -0,0 +1,45 @@
package v2
import (
"crypto/hmac"
"fmt"
"hash"
"strconv"
"time"
)
type Generator struct {
algo func() hash.Hash
algoType AuthAlgo
secret []byte
timeToLive time.Duration
}
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
algoFunc := algo.New()
if algoFunc == nil {
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
}
return &Generator{
algo: algoFunc,
algoType: algo,
secret: secret,
timeToLive: timeToLive,
}, nil
}
func (g *Generator) GenerateToken() (*Token, error) {
expirationTime := time.Now().Add(g.timeToLive).Unix()
payload := []byte(strconv.FormatInt(expirationTime, 10))
h := hmac.New(g.algo, g.secret)
h.Write(payload)
signature := h.Sum(nil)
return &Token{
AuthAlgo: g.algoType,
Signature: signature,
Payload: payload,
}, nil
}

View File

@ -0,0 +1,110 @@
package v2
import (
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(token.Payload) == 0 {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
timeToLive := -1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@ -0,0 +1,39 @@
package v2
import "errors"
type Token struct {
AuthAlgo AuthAlgo
Signature []byte
Payload []byte
}
func (t *Token) Marshal() []byte {
size := 1 + len(t.Signature) + len(t.Payload)
buf := make([]byte, size)
buf[0] = byte(t.AuthAlgo)
copy(buf[1:], t.Signature)
copy(buf[1+len(t.Signature):], t.Payload)
return buf
}
func UnmarshalToken(data []byte) (*Token, error) {
if len(data) == 0 {
return nil, errors.New("invalid token data")
}
algo := AuthAlgo(data[0])
sigSize := algo.Size()
if len(data) < 1+sigSize {
return nil, errors.New("invalid token data: insufficient length")
}
return &Token{
AuthAlgo: algo,
Signature: data[1 : 1+sigSize],
Payload: data[1+sigSize:],
}, nil
}

View File

@ -0,0 +1,59 @@
package v2
import (
"crypto/hmac"
"errors"
"fmt"
"strconv"
"time"
)
const minLengthUnixTimestamp = 10
type Validator struct {
secret []byte
}
func NewValidator(secret []byte) *Validator {
return &Validator{secret: secret}
}
func (v *Validator) Validate(data any) error {
d, ok := data.([]byte)
if !ok {
return fmt.Errorf("invalid data type")
}
token, err := UnmarshalToken(d)
if err != nil {
return fmt.Errorf("unmarshal token: %w", err)
}
if len(token.Payload) < minLengthUnixTimestamp {
return errors.New("invalid payload: insufficient length")
}
hashFunc := token.AuthAlgo.New()
if hashFunc == nil {
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
}
h := hmac.New(hashFunc, v.secret)
h.Write(token.Payload)
expectedMAC := h.Sum(nil)
if !hmac.Equal(token.Signature, expectedMAC) {
return errors.New("invalid signature")
}
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timestamp {
return fmt.Errorf("expired token")
}
return nil
}

View File

@ -1,8 +1,8 @@
package hmac
import (
"crypto/sha256"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
}
}
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte)
if !ok {
return fmt.Errorf("invalid credentials type")
@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
log.Debugf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(algo, c)
return a.TimedHMAC.Validate(sha256.New, c)
}

View File

@ -1,8 +1,35 @@
package auth
import "hash"
import (
"time"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(func() hash.Hash, any) error
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator
}
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
return &TimedHMACValidator{
authenticatorV2: authv2.NewValidator(secret),
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
return a.authenticatorV2.Validate(credentials)
}
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
return a.authenticator.Validate(credentials)
}