mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 03:11:02 +01:00
106 lines
2.5 KiB
Go
106 lines
2.5 KiB
Go
|
package hmac
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/hmac"
|
||
|
"encoding/base64"
|
||
|
"encoding/gob"
|
||
|
"fmt"
|
||
|
"hash"
|
||
|
"strconv"
|
||
|
"time"
|
||
|
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
)
|
||
|
|
||
|
// Token is a signed credential made of a payload and a signature over it.
// Both fields are exported so the value can be serialized with encoding/gob.
type Token struct {
	// Payload is the signed data. Tokens created by TimedHMAC.GenerateToken
	// carry the expiry time as a base-10 Unix timestamp in seconds.
	Payload string
	// Signature is the base64 (standard encoding) HMAC computed over Payload.
	Signature string
}
|
||
|
|
||
|
func marshalToken(token Token) ([]byte, error) {
|
||
|
var buffer bytes.Buffer
|
||
|
encoder := gob.NewEncoder(&buffer)
|
||
|
err := encoder.Encode(token)
|
||
|
if err != nil {
|
||
|
log.Debugf("failed to marshal token: %s", err)
|
||
|
return nil, fmt.Errorf("failed to marshal token: %w", err)
|
||
|
}
|
||
|
return buffer.Bytes(), nil
|
||
|
}
|
||
|
|
||
|
func unmarshalToken(payload []byte) (Token, error) {
|
||
|
var creds Token
|
||
|
buffer := bytes.NewBuffer(payload)
|
||
|
decoder := gob.NewDecoder(buffer)
|
||
|
err := decoder.Decode(&creds)
|
||
|
return creds, err
|
||
|
}
|
||
|
|
||
|
// TimedHMAC generates and validates tokens with a TTL, signed with a
// pre-shared secret known to the relay server.
type TimedHMAC struct {
	// secret is the pre-shared key used for the HMAC computation.
	secret string
	// timeToLive is added to the current time to obtain the expiry timestamp
	// of a newly generated token.
	timeToLive time.Duration
}
|
||
|
|
||
|
// NewTimedHMAC creates a new TimedHMAC instance
|
||
|
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
|
||
|
return &TimedHMAC{
|
||
|
secret: secret,
|
||
|
timeToLive: timeToLive,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
|
||
|
// hash of a timestamp with a preshared TURN secret
|
||
|
func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
|
||
|
timeAuth := time.Now().Add(m.timeToLive).Unix()
|
||
|
timeStamp := strconv.FormatInt(timeAuth, 10)
|
||
|
|
||
|
checksum, err := m.generate(algo, timeStamp)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &Token{
|
||
|
Payload: timeStamp,
|
||
|
Signature: base64.StdEncoding.EncodeToString(checksum),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// Validate checks if the token is valid
|
||
|
func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
|
||
|
expectedMAC, err := m.generate(algo, token.Payload)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
|
||
|
|
||
|
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
|
||
|
return fmt.Errorf("signature mismatch")
|
||
|
}
|
||
|
|
||
|
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("invalid payload: %w", err)
|
||
|
}
|
||
|
|
||
|
if time.Now().Unix() > timeAuthInt {
|
||
|
return fmt.Errorf("expired token")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
|
||
|
mac := hmac.New(algo, []byte(m.secret))
|
||
|
_, err := mac.Write([]byte(payload))
|
||
|
if err != nil {
|
||
|
log.Debugf("failed to generate token: %s", err)
|
||
|
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||
|
}
|
||
|
|
||
|
return mac.Sum(nil), nil
|
||
|
}
|