Integrate the relay authentication

This commit is contained in:
Zoltan Papp 2024-07-05 16:12:30 +02:00
parent 8845e8fbc7
commit 836072098b
30 changed files with 3055 additions and 1594 deletions

View File

@ -86,7 +86,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util"
@ -245,9 +246,10 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalConnected()
relayAddress := relayAddress(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayAddress, myPrivateKey.PublicKey().String())
if relayAddress != "" {
relayURL, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURL, myPrivateKey.PublicKey().String())
if relayURL != "" {
relayManager.UpdateToken(token)
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
@ -307,15 +309,27 @@ func (c *ConnectClient) run(
return nil
}
func relayAddress(resp *mgmProto.LoginResponse) string {
func parseRelayInfo(resp *mgmProto.LoginResponse) (string, hmac.Token) {
// todo remove this
if ra := peer.ForcedRelayAddress(); ra != "" {
return ra
return ra, hmac.Token{}
}
if resp.GetWiretrusteeConfig().GetRelayAddress() != "" {
return resp.GetWiretrusteeConfig().GetRelayAddress()
msg := resp.GetWiretrusteeConfig().GetRelay()
if msg == nil {
return "", hmac.Token{}
}
return ""
var url string
if msg.GetUrls() != nil && len(msg.GetUrls()) > 0 {
url = msg.GetUrls()[0]
}
token := hmac.Token{
Payload: msg.GetTokenPayload(),
Signature: msg.GetTokenSignature(),
}
return url, token
}
func (c *ConnectClient) Engine() *Engine {

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
@ -467,12 +469,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil {
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
wCfg := update.GetWiretrusteeConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return err
}
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return err
}
@ -482,8 +485,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
stunTurn = append(stunTurn, e.TURNs...)
e.StunTurn.Store(stunTurn)
// todo update relay address in the relay manager
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
c := auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
}
e.relayManager.UpdateToken(c)
}
// todo update relay address in the relay manager
// todo update signal
}

View File

@ -1071,7 +1071,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err

View File

@ -122,7 +122,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err

View File

@ -75,7 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)

View File

@ -183,7 +183,7 @@ var (
return fmt.Errorf("failed to build default manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnRelayTokenManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.RelayAddress)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@ -260,7 +260,7 @@ var (
ephemeralManager.LoadInitialPeers()
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}

File diff suppressed because it is too large Load Diff

View File

@ -147,7 +147,7 @@ message WiretrusteeConfig {
// a Signal server config
HostConfig signal = 3;
string RelayAddress = 4;
RelayConfig relay = 4;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@ -164,6 +164,13 @@ message HostConfig {
DTLS = 4;
}
}
message RelayConfig {
repeated string urls = 1;
string tokenPayload = 2;
string tokenSignature = 3;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {

View File

@ -166,9 +166,9 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS turnCfg should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS turnCfg should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS turnCfg should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
@ -178,13 +178,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS turnCfg should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS turnCfg should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS turnCfg should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS turnCfg should have DNS service enabled for peers not in the disabled group")
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS turnCfg should have 1 nameserver groups since peer 2 is part of the group All")
}
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {

View File

@ -31,7 +31,7 @@ type GRPCServer struct {
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
turnRelayTokenManager TURNRelayTokenManager
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
@ -39,7 +39,7 @@ type GRPCServer struct {
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@ -58,7 +58,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.Debug("unable to use http config to create new jwt middleware")
log.Debug("unable to use http turnCfg to create new jwt middleware")
}
if appMetrics != nil {
@ -87,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
turnRelayTokenManager: turnRelayTokenManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
@ -150,7 +150,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(peer)
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
s.turnRelayTokenManager.SetupRefresh(peer.ID)
}
if s.appMetrics != nil {
@ -201,7 +201,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
s.turnRelayTokenManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer)
}
@ -377,9 +377,14 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(peer)
}
trt, err := s.turnRelayTokenManager.Generate()
if err != nil {
log.Errorf("failed generating TURN and Relay token: %v", err)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
@ -407,11 +412,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
panic(fmt.Errorf("unexpected turnCfg protocol type %v", configProto))
}
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
func toWiretrusteeConfig(config *Config, turnCredentials *TURNRelayToken, relayToken *TURNRelayToken) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
@ -427,8 +432,8 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Username
password = turnCredentials.Password
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
@ -443,6 +448,18 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
})
}
var relayCfg *proto.RelayConfig
if config.RelayAddress != "" {
relayCfg = &proto.RelayConfig{
Urls: []string{config.RelayAddress},
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
return &proto.WiretrusteeConfig{
Stuns: stuns,
Turns: turns,
@ -450,7 +467,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
RelayAddress: config.RelayAddress,
Relay: relayCfg,
}
}
@ -478,8 +495,8 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@ -520,14 +537,16 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
creds := s.turnCredentialsManager.GenerateCredentials()
turnCredentials = &creds
} else {
turnCredentials = nil
var turnCredentials *TURNRelayToken
trt, err := s.turnRelayTokenManager.Generate()
if err != nil {
log.Errorf("failed generating TURN and Relay token: %v", err)
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
if s.config.TURNConfig.TimeBasedCredentials {
turnCredentials = trt
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@ -169,7 +169,7 @@ func Test_SyncProtocol(t *testing.T) {
}
if wiretrusteeConfig.GetSignal() == nil {
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal config")
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal turnCfg")
}
expectedSignalConfig := &mgmtProto.HostConfig{
@ -418,7 +418,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)

View File

@ -544,7 +544,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)

View File

@ -900,7 +900,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
continue
}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
update := toSyncResponse(nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}

View File

@ -0,0 +1,126 @@
package server
import (
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
// TURNRelayTokenManager used to manage TURN credentials
type TURNRelayTokenManager interface {
Generate() (*TURNRelayToken, error)
SetupRefresh(peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayAddr string
hmacToken *auth.TimedHMAC
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNRelayToken auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayAddress string) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
updateManager: updateManager,
turnCfg: turnCfg,
relayAddr: relayAddress,
hmacToken: auth.NewTimedHMAC(turnCfg.Secret, turnCfg.CredentialsTTL.Duration),
cancelMap: make(map[string]chan struct{}),
}
}
// Generate generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) Generate() (*TURNRelayToken, error) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
return nil, fmt.Errorf("failed to generate token: %s", err)
}
return (*TURNRelayToken)(token), nil
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
defer ticker.Stop()
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTokens(peerID)
}
}
}()
}
func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
return
}
var turns []*proto.ProtectedHostConfig
for _, host := range m.turnCfg.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: token.Payload,
Password: token.Signature,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
Relay: &proto.RelayConfig{
Urls: []string{m.relayAddr},
TokenPayload: token.Payload,
TokenSignature: token.Signature,
},
},
}
log.Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
}

View File

@ -26,18 +26,18 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
credentials := tested.GenerateCredentials()
credentials, _ := tested.Generate()
if credentials.Username == "" {
if credentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if credentials.Password == "" {
if credentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
validateMAC(t, credentials.Payload, credentials.Signature, []byte(secret))
}
@ -52,7 +52,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
tested.SetupRefresh(peer)
@ -100,7 +100,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
tested.SetupRefresh(peer)
if _, ok := tested.cancelMap[peer]; !ok {

View File

@ -1,125 +0,0 @@
package server
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
)
// TURNCredentialsManager used to manage TURN credentials
type TURNCredentialsManager interface {
GenerateCredentials() TURNCredentials
SetupRefresh(peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
config *TURNConfig
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNCredentials struct {
Username string
Password string
}
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
config: config,
updateManager: updateManager,
cancelMap: make(map[string]chan struct{}),
}
}
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
mac := hmac.New(sha1.New, []byte(m.config.Secret))
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
username := fmt.Sprint(timeAuth)
_, err := mac.Write([]byte(username))
if err != nil {
log.Errorln("Generating turn password failed with error: ", err)
}
bytePassword := mac.Sum(nil)
password := base64.StdEncoding.EncodeToString(bytePassword)
return TURNCredentials{
Username: username,
Password: password,
}
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
c := m.GenerateCredentials()
var turns []*proto.ProtectedHostConfig
for _, host := range m.config.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: c.Username,
Password: c.Password,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
},
}
log.Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
}
}
}()
}

9
relay/auth/allow_all.go Normal file
View File

@ -0,0 +1,9 @@
package auth
// AllowAllAuth is a Validator that allows all connections.
type AllowAllAuth struct {
}
func (a *AllowAllAuth) Validate(any) error {
return nil
}

24
relay/auth/hmac/store.go Normal file
View File

@ -0,0 +1,24 @@
package hmac
import (
"sync"
)
// Store is a simple in-memory store for token
// With this can update the token in thread safe way
type Store struct {
mu sync.Mutex
token Token
}
func (a *Store) UpdateToken(token Token) {
a.mu.Lock()
defer a.mu.Unlock()
a.token = token
}
func (a *Store) Token() ([]byte, error) {
a.mu.Lock()
defer a.mu.Unlock()
return marshalToken(a.token)
}

104
relay/auth/hmac/token.go Normal file
View File

@ -0,0 +1,104 @@
package hmac
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/gob"
"fmt"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type Token struct {
Payload string
Signature string
}
func marshalToken(token Token) ([]byte, error) {
buffer := bytes.NewBuffer([]byte{})
encoder := gob.NewEncoder(buffer)
err := encoder.Encode(token)
if err != nil {
log.Errorf("failed to marshal token: %s", err)
return nil, err
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
decoder := gob.NewDecoder(buffer)
err := decoder.Decode(&creds)
return creds, err
}
// TimedHMAC generates token with TTL and using pre-shared secret known to TURN server
type TimedHMAC struct {
mux sync.Mutex
secret string
timeToLive time.Duration
}
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
return &TimedHMAC{
secret: secret,
timeToLive: timeToLive,
}
}
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimedHMAC) GenerateToken() (*Token, error) {
timeAuth := time.Now().Add(m.timeToLive).Unix()
timeStamp := fmt.Sprint(timeAuth)
checksum, err := m.generate(timeStamp)
if err != nil {
return nil, err
}
return &Token{
Payload: timeStamp,
Signature: base64.StdEncoding.EncodeToString(checksum),
}, nil
}
func (m *TimedHMAC) Validate(token Token) error {
expectedMAC, err := m.generate(token.Payload)
if err != nil {
return err
}
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
return fmt.Errorf("signature mismatch")
}
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %s", err)
}
if time.Now().Unix() > timeAuthInt {
return fmt.Errorf("expired token")
}
return nil
}
func (m *TimedHMAC) generate(payload string) ([]byte, error) {
mac := hmac.New(sha1.New, []byte(m.secret))
_, err := mac.Write([]byte(payload))
if err != nil {
log.Errorf("failed to generate token: %s", err)
return nil, err
}
return mac.Sum(nil), nil
}

View File

@ -0,0 +1,103 @@
package hmac
import (
"encoding/base64"
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "secret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if creds.Payload == "" {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(creds.Payload, 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
_, err = base64.StdEncoding.DecodeString(creds.Signature)
if err != nil {
t.Fatalf("expected signature to be base64 encoded, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
// Test valid token
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err := manager.Validate(*creds); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
invalidCreds := &Token{
Payload: creds.Payload,
Signature: "invalidsignature",
}
if err = manager.Validate(*invalidCreds); err == nil {
t.Fatalf("expected invalid token due to signature mismatch")
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
v := NewTimedHMAC(secret, -1*time.Hour)
expiredCreds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err = v.Validate(*expiredCreds); err == nil {
t.Fatalf("expected invalid token due to expiration")
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
// Test invalid payload
invalidPayloadCreds := &Token{
Payload: "invalidtimestamp",
Signature: creds.Signature,
}
if err = v.Validate(*invalidPayloadCreds); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@ -0,0 +1,27 @@
package hmac
import (
log "github.com/sirupsen/logrus"
"time"
)
type TimedHMACValidator struct {
*TimedHMAC
}
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
ta := NewTimedHMAC(secret, duration)
return &TimedHMACValidator{
ta,
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
b := credentials.([]byte)
c, err := unmarshalToken(b)
if err != nil {
log.Errorf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(c)
}

5
relay/auth/validator.go Normal file
View File

@ -0,0 +1,5 @@
package auth
type Validator interface {
Validate(any) error
}

View File

@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
@ -98,6 +99,7 @@ type Client struct {
log *log.Entry
parentCtx context.Context
connectionURL string
authStore *auth.Store
hashedID []byte
bufPool *sync.Pool
@ -115,12 +117,13 @@ type Client struct {
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL, peerID string) *Client {
func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
connectionURL: serverURL,
authStore: authStore,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
@ -234,7 +237,12 @@ func (c *Client) connect() error {
}
func (c *Client) handShake() error {
msg, err := messages.MarshalHelloMsg(c.hashedID)
t, err := c.authStore.Token()
if err != nil {
return err
}
msg, err := messages.MarshalHelloMsg(c.hashedID, t)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
@ -262,11 +270,11 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type")
}
domain, err := messages.UnmarshalHelloResponse(buf[:n])
ia, err := messages.UnmarshalHelloResponse(buf[:n])
if err != nil {
return err
}
c.instanceURL = domain
c.instanceURL = ia
return nil
}

View File

@ -8,6 +8,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
)
var (
@ -38,6 +40,7 @@ type Manager struct {
ctx context.Context
serverURL string
peerID string
tokenStore *relayAuth.Store
relayClient *Client
reconnectGuard *Guard
@ -54,6 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
ctx: ctx,
serverURL: serverURL,
peerID: peerID,
tokenStore: &relayAuth.Store{},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
}
@ -65,7 +69,7 @@ func (m *Manager) Serve() error {
return fmt.Errorf("manager already serving")
}
m.relayClient = NewClient(m.ctx, m.serverURL, m.peerID)
m.relayClient = NewClient(m.ctx, m.serverURL, m.tokenStore, m.peerID)
err := m.relayClient.Connect()
if err != nil {
log.Errorf("failed to connect to relay server: %s", err)
@ -158,7 +162,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
@ -260,3 +264,7 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Unlock()
}
func (m *Manager) UpdateToken(token relayAuth.Token) {
m.tokenStore.UpdateToken(token)
}

View File

@ -24,6 +24,7 @@ var (
letsencryptDomains []string
tlsCertFile string
tlsKeyFile string
authSecret string
rootCmd = &cobra.Command{
Use: "relay",
@ -41,7 +42,7 @@ func init() {
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
rootCmd.PersistentFlags().StringVarP(&authSecret, "auth-secret", "s", "", "log level")
}
func waitForExitSignal() {
@ -56,6 +57,11 @@ func execute(cmd *cobra.Command, args []string) {
os.Exit(1)
}
if authSecret == "" {
log.Errorf("auth secret is required")
os.Exit(1)
}
srvListenerCfg := server.ListenerConfig{
Address: listenAddress,
}
@ -76,7 +82,7 @@ func execute(cmd *cobra.Command, args []string) {
}
tlsSupport := srvListenerCfg.TLSConfig != nil
srv := server.NewServer(exposedAddress, tlsSupport)
srv := server.NewServer(exposedAddress, tlsSupport, authSecret)
log.Infof("server will be available on: %s", srv.InstanceURL())
err := srv.Listen(srvListenerCfg)
if err != nil {

View File

@ -15,8 +15,10 @@ const (
MsgTypeClose MsgType = 3
MsgTypeHealthCheck MsgType = 4
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
sizeOfMsgType = 1
sizeOfMagicBye = 4
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
MaxHandshakeSize = 90
)
@ -47,7 +49,7 @@ func (m MsgType) String() string {
}
type HelloResponse struct {
DomainAddress string
InstanceAddress string
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
@ -83,28 +85,29 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
}
// MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, 5, headerSizeHello)
msg := make([]byte, 5, headerSizeHello+len(additions))
msg[0] = byte(MsgTypeHello)
copy(msg[1:5], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, additions...)
return msg, nil
}
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, fmt.Errorf("invalid 'hello' messge")
return nil, nil, fmt.Errorf("invalid 'hello' messge")
}
bytes.Equal(msg[1:5], magicHeader)
return msg[5:], nil
return msg[5:], msg[headerSizeHello:], nil
}
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
payload := HelloResponse{
DomainAddress: DomainAddress,
InstanceAddress: DomainAddress,
}
buf := new(bytes.Buffer)
@ -135,7 +138,7 @@ func UnmarshalHelloResponse(msg []byte) (string, error) {
log.Errorf("failed to gob decode hello response: %s", err)
return "", err
}
return payload.DomainAddress, nil
return payload.InstanceAddress, nil
}
// Close message

View File

@ -6,12 +6,12 @@ import (
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID)
bHello, err := MarshalHelloMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, err := UnmarshalHelloMsg(bHello)
receivedPeerID, _, err := UnmarshalHelloMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}

View File

@ -8,19 +8,23 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
)
type Relay struct {
validator auth.Validator
store *Store
instaceURL string // domain:port
instaceURL string
closed bool
closeMu sync.RWMutex
}
func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
func NewRelay(exposedAddress string, tlsSupport bool, validator auth.Validator) *Relay {
r := &Relay{
validator: validator,
store: NewStore(),
}
@ -29,6 +33,7 @@ func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
} else {
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
}
return r
}
@ -94,12 +99,17 @@ func (r *Relay) handShake(conn net.Conn) ([]byte, error) {
return nil, tErr
}
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
if err := r.validator.Validate(authPayload); err != nil {
log.Errorf("failed to authenticate peer with id: %s, %s", peerID, err)
return nil, fmt.Errorf("failed to authenticate peer")
}
msg, _ := messages.MarshalHelloResponse(r.instaceURL)
_, err = conn.Write(msg)
if err != nil {

View File

@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp"
"github.com/netbirdio/netbird/relay/server/listener/ws"
@ -25,9 +26,12 @@ type Server struct {
wSListener listener.Listener
}
func NewServer(exposedAddress string, tlsSupport bool) *Server {
func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server {
return &Server{
relay: NewRelay(exposedAddress, tlsSupport),
relay: NewRelay(
exposedAddress,
tlsSupport,
auth.NewTimedHMACValidator(authSecret, 24*time.Hour)),
}
}