mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-22 18:51:34 +02:00
Integrate the relay authentication
This commit is contained in:
parent
8845e8fbc7
commit
836072098b
@ -86,7 +86,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@ -245,9 +246,10 @@ func (c *ConnectClient) run(
|
|||||||
|
|
||||||
c.statusRecorder.MarkSignalConnected()
|
c.statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
relayAddress := relayAddress(loginResp)
|
relayURL, token := parseRelayInfo(loginResp)
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayAddress, myPrivateKey.PublicKey().String())
|
relayManager := relayClient.NewManager(engineCtx, relayURL, myPrivateKey.PublicKey().String())
|
||||||
if relayAddress != "" {
|
if relayURL != "" {
|
||||||
|
relayManager.UpdateToken(token)
|
||||||
if err = relayManager.Serve(); err != nil {
|
if err = relayManager.Serve(); err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@ -307,15 +309,27 @@ func (c *ConnectClient) run(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayAddress(resp *mgmProto.LoginResponse) string {
|
func parseRelayInfo(resp *mgmProto.LoginResponse) (string, hmac.Token) {
|
||||||
|
// todo remove this
|
||||||
if ra := peer.ForcedRelayAddress(); ra != "" {
|
if ra := peer.ForcedRelayAddress(); ra != "" {
|
||||||
return ra
|
return ra, hmac.Token{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.GetWiretrusteeConfig().GetRelayAddress() != "" {
|
msg := resp.GetWiretrusteeConfig().GetRelay()
|
||||||
return resp.GetWiretrusteeConfig().GetRelayAddress()
|
if msg == nil {
|
||||||
|
return "", hmac.Token{}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
|
var url string
|
||||||
|
if msg.GetUrls() != nil && len(msg.GetUrls()) > 0 {
|
||||||
|
url = msg.GetUrls()[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
token := hmac.Token{
|
||||||
|
Payload: msg.GetTokenPayload(),
|
||||||
|
Signature: msg.GetTokenSignature(),
|
||||||
|
}
|
||||||
|
return url, token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) Engine() *Engine {
|
func (c *ConnectClient) Engine() *Engine {
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
@ -36,6 +37,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
@ -467,12 +469,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if update.GetWiretrusteeConfig() != nil {
|
if update.GetWiretrusteeConfig() != nil {
|
||||||
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
|
wCfg := update.GetWiretrusteeConfig()
|
||||||
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
|
err = e.updateSTUNs(wCfg.GetStuns())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -482,8 +485,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
e.StunTurn.Store(stunTurn)
|
e.StunTurn.Store(stunTurn)
|
||||||
|
|
||||||
// todo update relay address in the relay manager
|
relayMsg := wCfg.GetRelay()
|
||||||
|
if relayMsg != nil {
|
||||||
|
c := auth.Token{
|
||||||
|
Payload: relayMsg.GetTokenPayload(),
|
||||||
|
Signature: relayMsg.GetTokenSignature(),
|
||||||
|
}
|
||||||
|
e.relayManager.UpdateToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo update relay address in the relay manager
|
||||||
// todo update signal
|
// todo update signal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1071,7 +1071,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
@ -122,7 +122,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
@ -75,7 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -183,7 +183,7 @@ var (
|
|||||||
return fmt.Errorf("failed to build default manager: %v", err)
|
return fmt.Errorf("failed to build default manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnRelayTokenManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.RelayAddress)
|
||||||
|
|
||||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||||
@ -260,7 +260,7 @@ var (
|
|||||||
ephemeralManager.LoadInitialPeers()
|
ephemeralManager.LoadInitialPeers()
|
||||||
|
|
||||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -147,7 +147,7 @@ message WiretrusteeConfig {
|
|||||||
// a Signal server config
|
// a Signal server config
|
||||||
HostConfig signal = 3;
|
HostConfig signal = 3;
|
||||||
|
|
||||||
string RelayAddress = 4;
|
RelayConfig relay = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||||
@ -164,6 +164,13 @@ message HostConfig {
|
|||||||
DTLS = 4;
|
DTLS = 4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message RelayConfig {
|
||||||
|
repeated string urls = 1;
|
||||||
|
string tokenPayload = 2;
|
||||||
|
string tokenSignature = 3;
|
||||||
|
}
|
||||||
|
|
||||||
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
||||||
// Mostly used for TURN servers
|
// Mostly used for TURN servers
|
||||||
message ProtectedHostConfig {
|
message ProtectedHostConfig {
|
||||||
|
@ -166,9 +166,9 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
|
|
||||||
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS turnCfg should have one custom zone for peers")
|
||||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS turnCfg should have local DNS service enabled")
|
||||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS turnCfg should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||||
|
|
||||||
dnsSettings := account.DNSSettings.Copy()
|
dnsSettings := account.DNSSettings.Copy()
|
||||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||||
@ -178,13 +178,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
|||||||
|
|
||||||
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS turnCfg should have no custom zone when peer belongs to a disabled group")
|
||||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS turnCfg should have local DNS service disabled when peer belongs to a disabled group")
|
||||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS turnCfg should have one custom zone for peers not in the disabled group")
|
||||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS turnCfg should have DNS service enabled for peers not in the disabled group")
|
||||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
|
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS turnCfg should have 1 nameserver groups since peer 2 is part of the group All")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||||
|
@ -29,17 +29,17 @@ type GRPCServer struct {
|
|||||||
accountManager AccountManager
|
accountManager AccountManager
|
||||||
wgKey wgtypes.Key
|
wgKey wgtypes.Key
|
||||||
proto.UnimplementedManagementServiceServer
|
proto.UnimplementedManagementServiceServer
|
||||||
peersUpdateManager *PeersUpdateManager
|
peersUpdateManager *PeersUpdateManager
|
||||||
config *Config
|
config *Config
|
||||||
turnCredentialsManager TURNCredentialsManager
|
turnRelayTokenManager TURNRelayTokenManager
|
||||||
jwtValidator *jwtclaims.JWTValidator
|
jwtValidator *jwtclaims.JWTValidator
|
||||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -58,7 +58,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debug("unable to use http config to create new jwt middleware")
|
log.Debug("unable to use http turnCfg to create new jwt middleware")
|
||||||
}
|
}
|
||||||
|
|
||||||
if appMetrics != nil {
|
if appMetrics != nil {
|
||||||
@ -84,14 +84,14 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
return &GRPCServer{
|
return &GRPCServer{
|
||||||
wgKey: key,
|
wgKey: key,
|
||||||
// peerKey -> event channel
|
// peerKey -> event channel
|
||||||
peersUpdateManager: peersUpdateManager,
|
peersUpdateManager: peersUpdateManager,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
config: config,
|
config: config,
|
||||||
turnCredentialsManager: turnCredentialsManager,
|
turnRelayTokenManager: turnRelayTokenManager,
|
||||||
jwtValidator: jwtValidator,
|
jwtValidator: jwtValidator,
|
||||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
ephemeralManager: ephemeralManager,
|
ephemeralManager: ephemeralManager,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
s.ephemeralManager.OnPeerConnected(peer)
|
s.ephemeralManager.OnPeerConnected(peer)
|
||||||
|
|
||||||
if s.config.TURNConfig.TimeBasedCredentials {
|
if s.config.TURNConfig.TimeBasedCredentials {
|
||||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
s.turnRelayTokenManager.SetupRefresh(peer.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
@ -201,7 +201,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
|
|
||||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
s.turnRelayTokenManager.CancelRefresh(peer.ID)
|
||||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||||
}
|
}
|
||||||
@ -377,9 +377,14 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trt, err := s.turnRelayTokenManager.Generate()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed generating TURN and Relay token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt),
|
||||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||||
}
|
}
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||||
@ -407,11 +412,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
|||||||
case TCP:
|
case TCP:
|
||||||
return proto.HostConfig_TCP
|
return proto.HostConfig_TCP
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
panic(fmt.Errorf("unexpected turnCfg protocol type %v", configProto))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
func toWiretrusteeConfig(config *Config, turnCredentials *TURNRelayToken, relayToken *TURNRelayToken) *proto.WiretrusteeConfig {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -427,8 +432,8 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
|||||||
var username string
|
var username string
|
||||||
var password string
|
var password string
|
||||||
if turnCredentials != nil {
|
if turnCredentials != nil {
|
||||||
username = turnCredentials.Username
|
username = turnCredentials.Payload
|
||||||
password = turnCredentials.Password
|
password = turnCredentials.Signature
|
||||||
} else {
|
} else {
|
||||||
username = turn.Username
|
username = turn.Username
|
||||||
password = turn.Password
|
password = turn.Password
|
||||||
@ -443,6 +448,18 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var relayCfg *proto.RelayConfig
|
||||||
|
if config.RelayAddress != "" {
|
||||||
|
relayCfg = &proto.RelayConfig{
|
||||||
|
Urls: []string{config.RelayAddress},
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayToken != nil {
|
||||||
|
relayCfg.TokenPayload = relayToken.Payload
|
||||||
|
relayCfg.TokenSignature = relayToken.Signature
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.WiretrusteeConfig{
|
return &proto.WiretrusteeConfig{
|
||||||
Stuns: stuns,
|
Stuns: stuns,
|
||||||
Turns: turns,
|
Turns: turns,
|
||||||
@ -450,7 +467,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
|||||||
Uri: config.Signal.URI,
|
Uri: config.Signal.URI,
|
||||||
Protocol: ToResponseProto(config.Signal.Proto),
|
Protocol: ToResponseProto(config.Signal.Proto),
|
||||||
},
|
},
|
||||||
RelayAddress: config.RelayAddress,
|
Relay: relayCfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -478,8 +495,8 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
|
|||||||
return remotePeers
|
return remotePeers
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
|
||||||
|
|
||||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||||
|
|
||||||
@ -520,14 +537,16 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
|||||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
||||||
// make secret time based TURN credentials optional
|
// make secret time based TURN credentials optional
|
||||||
var turnCredentials *TURNCredentials
|
var turnCredentials *TURNRelayToken
|
||||||
if s.config.TURNConfig.TimeBasedCredentials {
|
trt, err := s.turnRelayTokenManager.Generate()
|
||||||
creds := s.turnCredentialsManager.GenerateCredentials()
|
if err != nil {
|
||||||
turnCredentials = &creds
|
log.Errorf("failed generating TURN and Relay token: %v", err)
|
||||||
} else {
|
|
||||||
turnCredentials = nil
|
|
||||||
}
|
}
|
||||||
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
if s.config.TURNConfig.TimeBasedCredentials {
|
||||||
|
turnCredentials = trt
|
||||||
|
}
|
||||||
|
|
||||||
|
plainResp := toSyncResponse(s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain())
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -169,7 +169,7 @@ func Test_SyncProtocol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if wiretrusteeConfig.GetSignal() == nil {
|
if wiretrusteeConfig.GetSignal() == nil {
|
||||||
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal config")
|
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal turnCfg")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedSignalConfig := &mgmtProto.HostConfig{
|
expectedSignalConfig := &mgmtProto.HostConfig{
|
||||||
@ -418,7 +418,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
|
|
||||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||||
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||||
|
@ -544,7 +544,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating a manager: %v", err)
|
log.Fatalf("failed creating a manager: %v", err)
|
||||||
}
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||||
|
@ -900,7 +900,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
||||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
update := toSyncResponse(nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
126
management/server/token_mgr.go
Normal file
126
management/server/token_mgr.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TURNRelayTokenManager used to manage TURN credentials
|
||||||
|
type TURNRelayTokenManager interface {
|
||||||
|
Generate() (*TURNRelayToken, error)
|
||||||
|
SetupRefresh(peerKey string)
|
||||||
|
CancelRefresh(peerKey string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||||
|
type TimeBasedAuthSecretsManager struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
turnCfg *TURNConfig
|
||||||
|
relayAddr string
|
||||||
|
hmacToken *auth.TimedHMAC
|
||||||
|
updateManager *PeersUpdateManager
|
||||||
|
cancelMap map[string]chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TURNRelayToken auth.Token
|
||||||
|
|
||||||
|
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayAddress string) *TimeBasedAuthSecretsManager {
|
||||||
|
return &TimeBasedAuthSecretsManager{
|
||||||
|
mux: sync.Mutex{},
|
||||||
|
updateManager: updateManager,
|
||||||
|
turnCfg: turnCfg,
|
||||||
|
relayAddr: relayAddress,
|
||||||
|
hmacToken: auth.NewTimedHMAC(turnCfg.Secret, turnCfg.CredentialsTTL.Duration),
|
||||||
|
cancelMap: make(map[string]chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
|
||||||
|
func (m *TimeBasedAuthSecretsManager) Generate() (*TURNRelayToken, error) {
|
||||||
|
token, err := m.hmacToken.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate token: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (*TURNRelayToken)(token), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
|
||||||
|
if channel, ok := m.cancelMap[peerID]; ok {
|
||||||
|
close(channel)
|
||||||
|
delete(m.cancelMap, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRefresh cancels scheduled peer credentials refresh
|
||||||
|
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.cancel(peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
|
||||||
|
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
|
||||||
|
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.cancel(peerID)
|
||||||
|
cancel := make(chan struct{}, 1)
|
||||||
|
m.cancelMap[peerID] = cancel
|
||||||
|
log.Debugf("starting turn refresh for %s", peerID)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||||
|
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cancel:
|
||||||
|
log.Debugf("stopping turn refresh for %s", peerID)
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
m.pushNewTokens(peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) {
|
||||||
|
token, err := m.hmacToken.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var turns []*proto.ProtectedHostConfig
|
||||||
|
for _, host := range m.turnCfg.Turns {
|
||||||
|
turns = append(turns, &proto.ProtectedHostConfig{
|
||||||
|
HostConfig: &proto.HostConfig{
|
||||||
|
Uri: host.URI,
|
||||||
|
Protocol: ToResponseProto(host.Proto),
|
||||||
|
},
|
||||||
|
User: token.Payload,
|
||||||
|
Password: token.Signature,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
update := &proto.SyncResponse{
|
||||||
|
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||||
|
Turns: turns,
|
||||||
|
Relay: &proto.RelayConfig{
|
||||||
|
Urls: []string{m.relayAddr},
|
||||||
|
TokenPayload: token.Payload,
|
||||||
|
TokenSignature: token.Signature,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
|
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
||||||
|
}
|
@ -26,18 +26,18 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
|||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*Host{TurnTestHost},
|
Turns: []*Host{TurnTestHost},
|
||||||
})
|
}, "")
|
||||||
|
|
||||||
credentials := tested.GenerateCredentials()
|
credentials, _ := tested.Generate()
|
||||||
|
|
||||||
if credentials.Username == "" {
|
if credentials.Payload == "" {
|
||||||
t.Errorf("expected generated TURN username not to be empty, got empty")
|
t.Errorf("expected generated TURN username not to be empty, got empty")
|
||||||
}
|
}
|
||||||
if credentials.Password == "" {
|
if credentials.Signature == "" {
|
||||||
t.Errorf("expected generated TURN password not to be empty, got empty")
|
t.Errorf("expected generated TURN password not to be empty, got empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
|
validateMAC(t, credentials.Payload, credentials.Signature, []byte(secret))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
|||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*Host{TurnTestHost},
|
Turns: []*Host{TurnTestHost},
|
||||||
})
|
}, "")
|
||||||
|
|
||||||
tested.SetupRefresh(peer)
|
tested.SetupRefresh(peer)
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
|||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*Host{TurnTestHost},
|
Turns: []*Host{TurnTestHost},
|
||||||
})
|
}, "")
|
||||||
|
|
||||||
tested.SetupRefresh(peer)
|
tested.SetupRefresh(peer)
|
||||||
if _, ok := tested.cancelMap[peer]; !ok {
|
if _, ok := tested.cancelMap[peer]; !ok {
|
@ -1,125 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TURNCredentialsManager used to manage TURN credentials
|
|
||||||
type TURNCredentialsManager interface {
|
|
||||||
GenerateCredentials() TURNCredentials
|
|
||||||
SetupRefresh(peerKey string)
|
|
||||||
CancelRefresh(peerKey string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
|
||||||
type TimeBasedAuthSecretsManager struct {
|
|
||||||
mux sync.Mutex
|
|
||||||
config *TURNConfig
|
|
||||||
updateManager *PeersUpdateManager
|
|
||||||
cancelMap map[string]chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type TURNCredentials struct {
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
|
|
||||||
return &TimeBasedAuthSecretsManager{
|
|
||||||
mux: sync.Mutex{},
|
|
||||||
config: config,
|
|
||||||
updateManager: updateManager,
|
|
||||||
cancelMap: make(map[string]chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
|
|
||||||
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
|
|
||||||
mac := hmac.New(sha1.New, []byte(m.config.Secret))
|
|
||||||
|
|
||||||
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
|
|
||||||
|
|
||||||
username := fmt.Sprint(timeAuth)
|
|
||||||
|
|
||||||
_, err := mac.Write([]byte(username))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorln("Generating turn password failed with error: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytePassword := mac.Sum(nil)
|
|
||||||
password := base64.StdEncoding.EncodeToString(bytePassword)
|
|
||||||
|
|
||||||
return TURNCredentials{
|
|
||||||
Username: username,
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
|
|
||||||
if channel, ok := m.cancelMap[peerID]; ok {
|
|
||||||
close(channel)
|
|
||||||
delete(m.cancelMap, peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CancelRefresh cancels scheduled peer credentials refresh
|
|
||||||
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
m.cancel(peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
|
|
||||||
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
|
|
||||||
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
m.cancel(peerID)
|
|
||||||
cancel := make(chan struct{}, 1)
|
|
||||||
m.cancelMap[peerID] = cancel
|
|
||||||
log.Debugf("starting turn refresh for %s", peerID)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
|
||||||
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-cancel:
|
|
||||||
log.Debugf("stopping turn refresh for %s", peerID)
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
c := m.GenerateCredentials()
|
|
||||||
var turns []*proto.ProtectedHostConfig
|
|
||||||
for _, host := range m.config.Turns {
|
|
||||||
turns = append(turns, &proto.ProtectedHostConfig{
|
|
||||||
HostConfig: &proto.HostConfig{
|
|
||||||
Uri: host.URI,
|
|
||||||
Protocol: ToResponseProto(host.Proto),
|
|
||||||
},
|
|
||||||
User: c.Username,
|
|
||||||
Password: c.Password,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
update := &proto.SyncResponse{
|
|
||||||
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
|
||||||
Turns: turns,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
|
||||||
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
9
relay/auth/allow_all.go
Normal file
9
relay/auth/allow_all.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
// AllowAllAuth is a Validator that allows all connections.
|
||||||
|
type AllowAllAuth struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AllowAllAuth) Validate(any) error {
|
||||||
|
return nil
|
||||||
|
}
|
24
relay/auth/hmac/store.go
Normal file
24
relay/auth/hmac/store.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package hmac
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is a simple in-memory store for token
|
||||||
|
// With this can update the token in thread safe way
|
||||||
|
type Store struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
token Token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Store) UpdateToken(token Token) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.token = token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Store) Token() ([]byte, error) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
return marshalToken(a.token)
|
||||||
|
}
|
104
relay/auth/hmac/token.go
Normal file
104
relay/auth/hmac/token.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package hmac
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
Payload string
|
||||||
|
Signature string
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalToken(token Token) ([]byte, error) {
|
||||||
|
buffer := bytes.NewBuffer([]byte{})
|
||||||
|
encoder := gob.NewEncoder(buffer)
|
||||||
|
err := encoder.Encode(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to marshal token: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buffer.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalToken(payload []byte) (Token, error) {
|
||||||
|
var creds Token
|
||||||
|
buffer := bytes.NewBuffer(payload)
|
||||||
|
decoder := gob.NewDecoder(buffer)
|
||||||
|
err := decoder.Decode(&creds)
|
||||||
|
return creds, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimedHMAC generates token with TTL and using pre-shared secret known to TURN server
|
||||||
|
type TimedHMAC struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
secret string
|
||||||
|
timeToLive time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
|
||||||
|
return &TimedHMAC{
|
||||||
|
secret: secret,
|
||||||
|
timeToLive: timeToLive,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC hash of a timestamp with a preshared TURN secret
|
||||||
|
func (m *TimedHMAC) GenerateToken() (*Token, error) {
|
||||||
|
timeAuth := time.Now().Add(m.timeToLive).Unix()
|
||||||
|
timeStamp := fmt.Sprint(timeAuth)
|
||||||
|
|
||||||
|
checksum, err := m.generate(timeStamp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Token{
|
||||||
|
Payload: timeStamp,
|
||||||
|
Signature: base64.StdEncoding.EncodeToString(checksum),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TimedHMAC) Validate(token Token) error {
|
||||||
|
expectedMAC, err := m.generate(token.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
|
||||||
|
|
||||||
|
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
|
||||||
|
return fmt.Errorf("signature mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid payload: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Unix() > timeAuthInt {
|
||||||
|
return fmt.Errorf("expired token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TimedHMAC) generate(payload string) ([]byte, error) {
|
||||||
|
mac := hmac.New(sha1.New, []byte(m.secret))
|
||||||
|
_, err := mac.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to generate token: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return mac.Sum(nil), nil
|
||||||
|
}
|
103
relay/auth/hmac/token_test.go
Normal file
103
relay/auth/hmac/token_test.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package hmac
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCredentials(t *testing.T) {
|
||||||
|
secret := "secret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
v := NewTimedHMAC(secret, timeToLive)
|
||||||
|
|
||||||
|
creds, err := v.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if creds.Payload == "" {
|
||||||
|
t.Fatalf("expected non-empty payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = strconv.ParseInt(creds.Payload, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base64.StdEncoding.DecodeString(creds.Signature)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected signature to be base64 encoded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateCredentials(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
manager := NewTimedHMAC(secret, timeToLive)
|
||||||
|
|
||||||
|
// Test valid token
|
||||||
|
creds, err := manager.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.Validate(*creds); err != nil {
|
||||||
|
t.Fatalf("expected valid token: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidSignature(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
manager := NewTimedHMAC(secret, timeToLive)
|
||||||
|
|
||||||
|
creds, err := manager.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidCreds := &Token{
|
||||||
|
Payload: creds.Payload,
|
||||||
|
Signature: "invalidsignature",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = manager.Validate(*invalidCreds); err == nil {
|
||||||
|
t.Fatalf("expected invalid token due to signature mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpired(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
v := NewTimedHMAC(secret, -1*time.Hour)
|
||||||
|
expiredCreds, err := v.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = v.Validate(*expiredCreds); err == nil {
|
||||||
|
t.Fatalf("expected invalid token due to expiration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidPayload(t *testing.T) {
|
||||||
|
secret := "supersecret"
|
||||||
|
timeToLive := 1 * time.Hour
|
||||||
|
v := NewTimedHMAC(secret, timeToLive)
|
||||||
|
|
||||||
|
creds, err := v.GenerateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid payload
|
||||||
|
invalidPayloadCreds := &Token{
|
||||||
|
Payload: "invalidtimestamp",
|
||||||
|
Signature: creds.Signature,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = v.Validate(*invalidPayloadCreds); err == nil {
|
||||||
|
t.Fatalf("expected invalid token due to invalid payload")
|
||||||
|
}
|
||||||
|
}
|
27
relay/auth/hmac/validator.go
Normal file
27
relay/auth/hmac/validator.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package hmac
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TimedHMACValidator struct {
|
||||||
|
*TimedHMAC
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
|
||||||
|
ta := NewTimedHMAC(secret, duration)
|
||||||
|
return &TimedHMACValidator{
|
||||||
|
ta,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||||
|
b := credentials.([]byte)
|
||||||
|
c, err := unmarshalToken(b)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to unmarshal token: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return a.TimedHMAC.Validate(c)
|
||||||
|
}
|
5
relay/auth/validator.go
Normal file
5
relay/auth/validator.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
type Validator interface {
|
||||||
|
Validate(any) error
|
||||||
|
}
|
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
@ -98,6 +99,7 @@ type Client struct {
|
|||||||
log *log.Entry
|
log *log.Entry
|
||||||
parentCtx context.Context
|
parentCtx context.Context
|
||||||
connectionURL string
|
connectionURL string
|
||||||
|
authStore *auth.Store
|
||||||
hashedID []byte
|
hashedID []byte
|
||||||
|
|
||||||
bufPool *sync.Pool
|
bufPool *sync.Pool
|
||||||
@ -115,12 +117,13 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||||
func NewClient(ctx context.Context, serverURL, peerID string) *Client {
|
func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client {
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID, hashedStringId := messages.HashID(peerID)
|
||||||
return &Client{
|
return &Client{
|
||||||
log: log.WithField("client_id", hashedStringId),
|
log: log.WithField("client_id", hashedStringId),
|
||||||
parentCtx: ctx,
|
parentCtx: ctx,
|
||||||
connectionURL: serverURL,
|
connectionURL: serverURL,
|
||||||
|
authStore: authStore,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
bufPool: &sync.Pool{
|
bufPool: &sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
@ -234,7 +237,12 @@ func (c *Client) connect() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake() error {
|
||||||
msg, err := messages.MarshalHelloMsg(c.hashedID)
|
t, err := c.authStore.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := messages.MarshalHelloMsg(c.hashedID, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to marshal hello message: %s", err)
|
log.Errorf("failed to marshal hello message: %s", err)
|
||||||
return err
|
return err
|
||||||
@ -262,11 +270,11 @@ func (c *Client) handShake() error {
|
|||||||
return fmt.Errorf("unexpected message type")
|
return fmt.Errorf("unexpected message type")
|
||||||
}
|
}
|
||||||
|
|
||||||
domain, err := messages.UnmarshalHelloResponse(buf[:n])
|
ia, err := messages.UnmarshalHelloResponse(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.instanceURL = domain
|
c.instanceURL = ia
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -35,9 +37,10 @@ func NewRelayTrack() *RelayTrack {
|
|||||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||||
// unused relay connection and close it.
|
// unused relay connection and close it.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
serverURL string
|
serverURL string
|
||||||
peerID string
|
peerID string
|
||||||
|
tokenStore *relayAuth.Store
|
||||||
|
|
||||||
relayClient *Client
|
relayClient *Client
|
||||||
reconnectGuard *Guard
|
reconnectGuard *Guard
|
||||||
@ -54,6 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
serverURL: serverURL,
|
serverURL: serverURL,
|
||||||
peerID: peerID,
|
peerID: peerID,
|
||||||
|
tokenStore: &relayAuth.Store{},
|
||||||
relayClients: make(map[string]*RelayTrack),
|
relayClients: make(map[string]*RelayTrack),
|
||||||
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
|
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
|
||||||
}
|
}
|
||||||
@ -65,7 +69,7 @@ func (m *Manager) Serve() error {
|
|||||||
return fmt.Errorf("manager already serving")
|
return fmt.Errorf("manager already serving")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.relayClient = NewClient(m.ctx, m.serverURL, m.peerID)
|
m.relayClient = NewClient(m.ctx, m.serverURL, m.tokenStore, m.peerID)
|
||||||
err := m.relayClient.Connect()
|
err := m.relayClient.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to connect to relay server: %s", err)
|
log.Errorf("failed to connect to relay server: %s", err)
|
||||||
@ -158,7 +162,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
m.relayClients[serverAddress] = rt
|
m.relayClients[serverAddress] = rt
|
||||||
m.relayClientsMutex.Unlock()
|
m.relayClientsMutex.Unlock()
|
||||||
|
|
||||||
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
|
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rt.Unlock()
|
rt.Unlock()
|
||||||
@ -260,3 +264,7 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
|||||||
m.listenerLock.Unlock()
|
m.listenerLock.Unlock()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) UpdateToken(token relayAuth.Token) {
|
||||||
|
m.tokenStore.UpdateToken(token)
|
||||||
|
}
|
||||||
|
@ -24,6 +24,7 @@ var (
|
|||||||
letsencryptDomains []string
|
letsencryptDomains []string
|
||||||
tlsCertFile string
|
tlsCertFile string
|
||||||
tlsKeyFile string
|
tlsKeyFile string
|
||||||
|
authSecret string
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "relay",
|
Use: "relay",
|
||||||
@ -41,7 +42,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||||
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
|
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
|
||||||
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
|
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&authSecret, "auth-secret", "s", "", "log level")
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForExitSignal() {
|
func waitForExitSignal() {
|
||||||
@ -56,6 +57,11 @@ func execute(cmd *cobra.Command, args []string) {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authSecret == "" {
|
||||||
|
log.Errorf("auth secret is required")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
srvListenerCfg := server.ListenerConfig{
|
srvListenerCfg := server.ListenerConfig{
|
||||||
Address: listenAddress,
|
Address: listenAddress,
|
||||||
}
|
}
|
||||||
@ -76,7 +82,7 @@ func execute(cmd *cobra.Command, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tlsSupport := srvListenerCfg.TLSConfig != nil
|
tlsSupport := srvListenerCfg.TLSConfig != nil
|
||||||
srv := server.NewServer(exposedAddress, tlsSupport)
|
srv := server.NewServer(exposedAddress, tlsSupport, authSecret)
|
||||||
log.Infof("server will be available on: %s", srv.InstanceURL())
|
log.Infof("server will be available on: %s", srv.InstanceURL())
|
||||||
err := srv.Listen(srvListenerCfg)
|
err := srv.Listen(srvListenerCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -15,8 +15,10 @@ const (
|
|||||||
MsgTypeClose MsgType = 3
|
MsgTypeClose MsgType = 3
|
||||||
MsgTypeHealthCheck MsgType = 4
|
MsgTypeHealthCheck MsgType = 4
|
||||||
|
|
||||||
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
|
sizeOfMsgType = 1
|
||||||
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
sizeOfMagicBye = 4
|
||||||
|
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
|
||||||
|
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
||||||
|
|
||||||
MaxHandshakeSize = 90
|
MaxHandshakeSize = 90
|
||||||
)
|
)
|
||||||
@ -47,7 +49,7 @@ func (m MsgType) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type HelloResponse struct {
|
type HelloResponse struct {
|
||||||
DomainAddress string
|
InstanceAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||||
@ -83,28 +85,29 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarshalHelloMsg initial hello message
|
// MarshalHelloMsg initial hello message
|
||||||
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if len(peerID) != IDSize {
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||||
}
|
}
|
||||||
msg := make([]byte, 5, headerSizeHello)
|
msg := make([]byte, 5, headerSizeHello+len(additions))
|
||||||
msg[0] = byte(MsgTypeHello)
|
msg[0] = byte(MsgTypeHello)
|
||||||
copy(msg[1:5], magicHeader)
|
copy(msg[1:5], magicHeader)
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID...)
|
||||||
|
msg = append(msg, additions...)
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||||
if len(msg) < headerSizeHello {
|
if len(msg) < headerSizeHello {
|
||||||
return nil, fmt.Errorf("invalid 'hello' messge")
|
return nil, nil, fmt.Errorf("invalid 'hello' messge")
|
||||||
}
|
}
|
||||||
bytes.Equal(msg[1:5], magicHeader)
|
bytes.Equal(msg[1:5], magicHeader)
|
||||||
return msg[5:], nil
|
return msg[5:], msg[headerSizeHello:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
|
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
|
||||||
payload := HelloResponse{
|
payload := HelloResponse{
|
||||||
DomainAddress: DomainAddress,
|
InstanceAddress: DomainAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
@ -135,7 +138,7 @@ func UnmarshalHelloResponse(msg []byte) (string, error) {
|
|||||||
log.Errorf("failed to gob decode hello response: %s", err)
|
log.Errorf("failed to gob decode hello response: %s", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return payload.DomainAddress, nil
|
return payload.InstanceAddress, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close message
|
// Close message
|
||||||
|
@ -6,12 +6,12 @@ import (
|
|||||||
|
|
||||||
func TestMarshalHelloMsg(t *testing.T) {
|
func TestMarshalHelloMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
bHello, err := MarshalHelloMsg(peerID)
|
bHello, err := MarshalHelloMsg(peerID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
receivedPeerID, err := UnmarshalHelloMsg(bHello)
|
receivedPeerID, _, err := UnmarshalHelloMsg(bHello)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -8,20 +8,24 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Relay struct {
|
type Relay struct {
|
||||||
|
validator auth.Validator
|
||||||
|
|
||||||
store *Store
|
store *Store
|
||||||
instaceURL string // domain:port
|
instaceURL string
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
closeMu sync.RWMutex
|
closeMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
|
func NewRelay(exposedAddress string, tlsSupport bool, validator auth.Validator) *Relay {
|
||||||
r := &Relay{
|
r := &Relay{
|
||||||
store: NewStore(),
|
validator: validator,
|
||||||
|
store: NewStore(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsSupport {
|
if tlsSupport {
|
||||||
@ -29,6 +33,7 @@ func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
|
|||||||
} else {
|
} else {
|
||||||
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
|
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,12 +99,17 @@ func (r *Relay) handShake(conn net.Conn) ([]byte, error) {
|
|||||||
return nil, tErr
|
return nil, tErr
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
|
peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to handshake: %s", err)
|
log.Errorf("failed to handshake: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.validator.Validate(authPayload); err != nil {
|
||||||
|
log.Errorf("failed to authenticate peer with id: %s, %s", peerID, err)
|
||||||
|
return nil, fmt.Errorf("failed to authenticate peer")
|
||||||
|
}
|
||||||
|
|
||||||
msg, _ := messages.MarshalHelloResponse(r.instaceURL)
|
msg, _ := messages.MarshalHelloResponse(r.instaceURL)
|
||||||
_, err = conn.Write(msg)
|
_, err = conn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener"
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
@ -25,9 +26,12 @@ type Server struct {
|
|||||||
wSListener listener.Listener
|
wSListener listener.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(exposedAddress string, tlsSupport bool) *Server {
|
func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
relay: NewRelay(exposedAddress, tlsSupport),
|
relay: NewRelay(
|
||||||
|
exposedAddress,
|
||||||
|
tlsSupport,
|
||||||
|
auth.NewTimedHMACValidator(authSecret, 24*time.Hour)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user