mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-31 10:31:58 +01:00
Integrate the relay authentication
This commit is contained in:
parent
8845e8fbc7
commit
836072098b
@ -86,7 +86,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -245,9 +246,10 @@ func (c *ConnectClient) run(
|
||||
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
relayAddress := relayAddress(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayAddress, myPrivateKey.PublicKey().String())
|
||||
if relayAddress != "" {
|
||||
relayURL, token := parseRelayInfo(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURL, myPrivateKey.PublicKey().String())
|
||||
if relayURL != "" {
|
||||
relayManager.UpdateToken(token)
|
||||
if err = relayManager.Serve(); err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@ -307,15 +309,27 @@ func (c *ConnectClient) run(
|
||||
return nil
|
||||
}
|
||||
|
||||
func relayAddress(resp *mgmProto.LoginResponse) string {
|
||||
func parseRelayInfo(resp *mgmProto.LoginResponse) (string, hmac.Token) {
|
||||
// todo remove this
|
||||
if ra := peer.ForcedRelayAddress(); ra != "" {
|
||||
return ra
|
||||
return ra, hmac.Token{}
|
||||
}
|
||||
|
||||
if resp.GetWiretrusteeConfig().GetRelayAddress() != "" {
|
||||
return resp.GetWiretrusteeConfig().GetRelayAddress()
|
||||
msg := resp.GetWiretrusteeConfig().GetRelay()
|
||||
if msg == nil {
|
||||
return "", hmac.Token{}
|
||||
}
|
||||
return ""
|
||||
|
||||
var url string
|
||||
if msg.GetUrls() != nil && len(msg.GetUrls()) > 0 {
|
||||
url = msg.GetUrls()[0]
|
||||
}
|
||||
|
||||
token := hmac.Token{
|
||||
Payload: msg.GetTokenPayload(),
|
||||
Signature: msg.GetTokenSignature(),
|
||||
}
|
||||
return url, token
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
@ -36,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@ -467,12 +469,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if update.GetWiretrusteeConfig() != nil {
|
||||
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
|
||||
wCfg := update.GetWiretrusteeConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -482,8 +485,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.StunTurn.Store(stunTurn)
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
relayMsg := wCfg.GetRelay()
|
||||
if relayMsg != nil {
|
||||
c := auth.Token{
|
||||
Payload: relayMsg.GetTokenPayload(),
|
||||
Signature: relayMsg.GetTokenSignature(),
|
||||
}
|
||||
e.relayManager.UpdateToken(c)
|
||||
}
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
|
@ -1071,7 +1071,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
@ -122,7 +122,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
@ -75,7 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -183,7 +183,7 @@ var (
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnRelayTokenManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.RelayAddress)
|
||||
|
||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
@ -260,7 +260,7 @@ var (
|
||||
ephemeralManager.LoadInitialPeers()
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -147,7 +147,7 @@ message WiretrusteeConfig {
|
||||
// a Signal server config
|
||||
HostConfig signal = 3;
|
||||
|
||||
string RelayAddress = 4;
|
||||
RelayConfig relay = 4;
|
||||
}
|
||||
|
||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||
@ -164,6 +164,13 @@ message HostConfig {
|
||||
DTLS = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message RelayConfig {
|
||||
repeated string urls = 1;
|
||||
string tokenPayload = 2;
|
||||
string tokenSignature = 3;
|
||||
}
|
||||
|
||||
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
||||
// Mostly used for TURN servers
|
||||
message ProtectedHostConfig {
|
||||
|
@ -166,9 +166,9 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS turnCfg should have one custom zone for peers")
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS turnCfg should have local DNS service enabled")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS turnCfg should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
@ -178,13 +178,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS turnCfg should have no custom zone when peer belongs to a disabled group")
|
||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS turnCfg should have local DNS service disabled when peer belongs to a disabled group")
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS turnCfg should have one custom zone for peers not in the disabled group")
|
||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS turnCfg should have DNS service enabled for peers not in the disabled group")
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS turnCfg should have 1 nameserver groups since peer 2 is part of the group All")
|
||||
}
|
||||
|
||||
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
@ -29,17 +29,17 @@ type GRPCServer struct {
|
||||
accountManager AccountManager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
turnCredentialsManager TURNCredentialsManager
|
||||
jwtValidator *jwtclaims.JWTValidator
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
turnRelayTokenManager TURNRelayTokenManager
|
||||
jwtValidator *jwtclaims.JWTValidator
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -58,7 +58,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("unable to use http config to create new jwt middleware")
|
||||
log.Debug("unable to use http turnCfg to create new jwt middleware")
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
@ -84,14 +84,14 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
config: config,
|
||||
turnCredentialsManager: turnCredentialsManager,
|
||||
jwtValidator: jwtValidator,
|
||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
config: config,
|
||||
turnRelayTokenManager: turnRelayTokenManager,
|
||||
jwtValidator: jwtValidator,
|
||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -150,7 +150,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
s.turnRelayTokenManager.SetupRefresh(peer.ID)
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
@ -201,7 +201,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
s.turnRelayTokenManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
}
|
||||
@ -377,9 +377,14 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
}
|
||||
|
||||
trt, err := s.turnRelayTokenManager.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("failed generating TURN and Relay token: %v", err)
|
||||
}
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||
}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
@ -407,11 +412,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||
case TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
panic(fmt.Errorf("unexpected turnCfg protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
||||
func toWiretrusteeConfig(config *Config, turnCredentials *TURNRelayToken, relayToken *TURNRelayToken) *proto.WiretrusteeConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@ -427,8 +432,8 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Username
|
||||
password = turnCredentials.Password
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
@ -443,6 +448,18 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
||||
})
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.RelayAddress != "" {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: []string{config.RelayAddress},
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.WiretrusteeConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
@ -450,7 +467,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
},
|
||||
RelayAddress: config.RelayAddress,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
}
|
||||
|
||||
@ -478,8 +495,8 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
|
||||
@ -520,14 +537,16 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
||||
// make secret time based TURN credentials optional
|
||||
var turnCredentials *TURNCredentials
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
creds := s.turnCredentialsManager.GenerateCredentials()
|
||||
turnCredentials = &creds
|
||||
} else {
|
||||
turnCredentials = nil
|
||||
var turnCredentials *TURNRelayToken
|
||||
trt, err := s.turnRelayTokenManager.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("failed generating TURN and Relay token: %v", err)
|
||||
}
|
||||
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
turnCredentials = trt
|
||||
}
|
||||
|
||||
plainResp := toSyncResponse(s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain())
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
@ -169,7 +169,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
if wiretrusteeConfig.GetSignal() == nil {
|
||||
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal config")
|
||||
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal turnCfg")
|
||||
}
|
||||
|
||||
expectedSignalConfig := &mgmtProto.HostConfig{
|
||||
@ -418,7 +418,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
|
@ -544,7 +544,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
@ -900,7 +900,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||
continue
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
update := toSyncResponse(nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
}
|
||||
|
126
management/server/token_mgr.go
Normal file
126
management/server/token_mgr.go
Normal file
@ -0,0 +1,126 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
// TURNRelayTokenManager used to manage TURN credentials
|
||||
type TURNRelayTokenManager interface {
|
||||
Generate() (*TURNRelayToken, error)
|
||||
SetupRefresh(peerKey string)
|
||||
CancelRefresh(peerKey string)
|
||||
}
|
||||
|
||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||
type TimeBasedAuthSecretsManager struct {
|
||||
mux sync.Mutex
|
||||
turnCfg *TURNConfig
|
||||
relayAddr string
|
||||
hmacToken *auth.TimedHMAC
|
||||
updateManager *PeersUpdateManager
|
||||
cancelMap map[string]chan struct{}
|
||||
}
|
||||
|
||||
type TURNRelayToken auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayAddress string) *TimeBasedAuthSecretsManager {
|
||||
return &TimeBasedAuthSecretsManager{
|
||||
mux: sync.Mutex{},
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
relayAddr: relayAddress,
|
||||
hmacToken: auth.NewTimedHMAC(turnCfg.Secret, turnCfg.CredentialsTTL.Duration),
|
||||
cancelMap: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
|
||||
func (m *TimeBasedAuthSecretsManager) Generate() (*TURNRelayToken, error) {
|
||||
token, err := m.hmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate token: %s", err)
|
||||
}
|
||||
|
||||
return (*TURNRelayToken)(token), nil
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
|
||||
if channel, ok := m.cancelMap[peerID]; ok {
|
||||
close(channel)
|
||||
delete(m.cancelMap, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelRefresh cancels scheduled peer credentials refresh
|
||||
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.cancel(peerID)
|
||||
}
|
||||
|
||||
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
|
||||
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
|
||||
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.cancel(peerID)
|
||||
cancel := make(chan struct{}, 1)
|
||||
m.cancelMap[peerID] = cancel
|
||||
log.Debugf("starting turn refresh for %s", peerID)
|
||||
|
||||
go func() {
|
||||
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cancel:
|
||||
log.Debugf("stopping turn refresh for %s", peerID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.pushNewTokens(peerID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) {
|
||||
token, err := m.hmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
for _, host := range m.turnCfg.Turns {
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: host.URI,
|
||||
Protocol: ToResponseProto(host.Proto),
|
||||
},
|
||||
User: token.Payload,
|
||||
Password: token.Signature,
|
||||
})
|
||||
}
|
||||
|
||||
update := &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||
Turns: turns,
|
||||
Relay: &proto.RelayConfig{
|
||||
Urls: []string{m.relayAddr},
|
||||
TokenPayload: token.Payload,
|
||||
TokenSignature: token.Signature,
|
||||
},
|
||||
},
|
||||
}
|
||||
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
||||
}
|
@ -26,18 +26,18 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
})
|
||||
}, "")
|
||||
|
||||
credentials := tested.GenerateCredentials()
|
||||
credentials, _ := tested.Generate()
|
||||
|
||||
if credentials.Username == "" {
|
||||
if credentials.Payload == "" {
|
||||
t.Errorf("expected generated TURN username not to be empty, got empty")
|
||||
}
|
||||
if credentials.Password == "" {
|
||||
if credentials.Signature == "" {
|
||||
t.Errorf("expected generated TURN password not to be empty, got empty")
|
||||
}
|
||||
|
||||
validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
|
||||
validateMAC(t, credentials.Payload, credentials.Signature, []byte(secret))
|
||||
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
})
|
||||
}, "")
|
||||
|
||||
tested.SetupRefresh(peer)
|
||||
|
||||
@ -100,7 +100,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*Host{TurnTestHost},
|
||||
})
|
||||
}, "")
|
||||
|
||||
tested.SetupRefresh(peer)
|
||||
if _, ok := tested.cancelMap[peer]; !ok {
|
@ -1,125 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// TURNCredentialsManager used to manage TURN credentials
|
||||
type TURNCredentialsManager interface {
|
||||
GenerateCredentials() TURNCredentials
|
||||
SetupRefresh(peerKey string)
|
||||
CancelRefresh(peerKey string)
|
||||
}
|
||||
|
||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||
type TimeBasedAuthSecretsManager struct {
|
||||
mux sync.Mutex
|
||||
config *TURNConfig
|
||||
updateManager *PeersUpdateManager
|
||||
cancelMap map[string]chan struct{}
|
||||
}
|
||||
|
||||
type TURNCredentials struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
|
||||
return &TimeBasedAuthSecretsManager{
|
||||
mux: sync.Mutex{},
|
||||
config: config,
|
||||
updateManager: updateManager,
|
||||
cancelMap: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
|
||||
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
|
||||
mac := hmac.New(sha1.New, []byte(m.config.Secret))
|
||||
|
||||
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
|
||||
|
||||
username := fmt.Sprint(timeAuth)
|
||||
|
||||
_, err := mac.Write([]byte(username))
|
||||
if err != nil {
|
||||
log.Errorln("Generating turn password failed with error: ", err)
|
||||
}
|
||||
|
||||
bytePassword := mac.Sum(nil)
|
||||
password := base64.StdEncoding.EncodeToString(bytePassword)
|
||||
|
||||
return TURNCredentials{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
|
||||
if channel, ok := m.cancelMap[peerID]; ok {
|
||||
close(channel)
|
||||
delete(m.cancelMap, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelRefresh cancels scheduled peer credentials refresh
|
||||
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.cancel(peerID)
|
||||
}
|
||||
|
||||
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
|
||||
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
|
||||
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.cancel(peerID)
|
||||
cancel := make(chan struct{}, 1)
|
||||
m.cancelMap[peerID] = cancel
|
||||
log.Debugf("starting turn refresh for %s", peerID)
|
||||
|
||||
go func() {
|
||||
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cancel:
|
||||
log.Debugf("stopping turn refresh for %s", peerID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
c := m.GenerateCredentials()
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
for _, host := range m.config.Turns {
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: host.URI,
|
||||
Protocol: ToResponseProto(host.Proto),
|
||||
},
|
||||
User: c.Username,
|
||||
Password: c.Password,
|
||||
})
|
||||
}
|
||||
|
||||
update := &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||
Turns: turns,
|
||||
},
|
||||
}
|
||||
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
9
relay/auth/allow_all.go
Normal file
9
relay/auth/allow_all.go
Normal file
@ -0,0 +1,9 @@
|
||||
package auth
|
||||
|
||||
// AllowAllAuth is a Validator that allows all connections.
|
||||
type AllowAllAuth struct {
|
||||
}
|
||||
|
||||
func (a *AllowAllAuth) Validate(any) error {
|
||||
return nil
|
||||
}
|
24
relay/auth/hmac/store.go
Normal file
24
relay/auth/hmac/store.go
Normal file
@ -0,0 +1,24 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store is a simple in-memory store for token
|
||||
// With this can update the token in thread safe way
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
token Token
|
||||
}
|
||||
|
||||
func (a *Store) UpdateToken(token Token) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.token = token
|
||||
}
|
||||
|
||||
func (a *Store) Token() ([]byte, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return marshalToken(a.token)
|
||||
}
|
104
relay/auth/hmac/token.go
Normal file
104
relay/auth/hmac/token.go
Normal file
@ -0,0 +1,104 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Payload string
|
||||
Signature string
|
||||
}
|
||||
|
||||
func marshalToken(token Token) ([]byte, error) {
|
||||
buffer := bytes.NewBuffer([]byte{})
|
||||
encoder := gob.NewEncoder(buffer)
|
||||
err := encoder.Encode(token)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal token: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func unmarshalToken(payload []byte) (Token, error) {
|
||||
var creds Token
|
||||
buffer := bytes.NewBuffer(payload)
|
||||
decoder := gob.NewDecoder(buffer)
|
||||
err := decoder.Decode(&creds)
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// TimedHMAC generates token with TTL and using pre-shared secret known to TURN server
|
||||
type TimedHMAC struct {
|
||||
mux sync.Mutex
|
||||
secret string
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
|
||||
return &TimedHMAC{
|
||||
secret: secret,
|
||||
timeToLive: timeToLive,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC hash of a timestamp with a preshared TURN secret
|
||||
func (m *TimedHMAC) GenerateToken() (*Token, error) {
|
||||
timeAuth := time.Now().Add(m.timeToLive).Unix()
|
||||
timeStamp := fmt.Sprint(timeAuth)
|
||||
|
||||
checksum, err := m.generate(timeStamp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Payload: timeStamp,
|
||||
Signature: base64.StdEncoding.EncodeToString(checksum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *TimedHMAC) Validate(token Token) error {
|
||||
expectedMAC, err := m.generate(token.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
|
||||
|
||||
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
|
||||
return fmt.Errorf("signature mismatch")
|
||||
}
|
||||
|
||||
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid payload: %s", err)
|
||||
}
|
||||
|
||||
if time.Now().Unix() > timeAuthInt {
|
||||
return fmt.Errorf("expired token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TimedHMAC) generate(payload string) ([]byte, error) {
|
||||
mac := hmac.New(sha1.New, []byte(m.secret))
|
||||
_, err := mac.Write([]byte(payload))
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate token: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mac.Sum(nil), nil
|
||||
}
|
103
relay/auth/hmac/token_test.go
Normal file
103
relay/auth/hmac/token_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateCredentials(t *testing.T) {
|
||||
secret := "secret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if creds.Payload == "" {
|
||||
t.Fatalf("expected non-empty payload")
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(creds.Payload, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||
}
|
||||
|
||||
_, err = base64.StdEncoding.DecodeString(creds.Signature)
|
||||
if err != nil {
|
||||
t.Fatalf("expected signature to be base64 encoded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
// Test valid token
|
||||
creds, err := manager.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err := manager.Validate(*creds); err != nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSignature(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := manager.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
invalidCreds := &Token{
|
||||
Payload: creds.Payload,
|
||||
Signature: "invalidsignature",
|
||||
}
|
||||
|
||||
if err = manager.Validate(*invalidCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
v := NewTimedHMAC(secret, -1*time.Hour)
|
||||
expiredCreds, err := v.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err = v.Validate(*expiredCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to expiration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPayload(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Test invalid payload
|
||||
invalidPayloadCreds := &Token{
|
||||
Payload: "invalidtimestamp",
|
||||
Signature: creds.Signature,
|
||||
}
|
||||
|
||||
if err = v.Validate(*invalidPayloadCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to invalid payload")
|
||||
}
|
||||
}
|
27
relay/auth/hmac/validator.go
Normal file
27
relay/auth/hmac/validator.go
Normal file
@ -0,0 +1,27 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TimedHMACValidator struct {
|
||||
*TimedHMAC
|
||||
}
|
||||
|
||||
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
|
||||
ta := NewTimedHMAC(secret, duration)
|
||||
return &TimedHMACValidator{
|
||||
ta,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||
b := credentials.([]byte)
|
||||
c, err := unmarshalToken(b)
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal token: %s", err)
|
||||
return err
|
||||
}
|
||||
return a.TimedHMAC.Validate(c)
|
||||
}
|
5
relay/auth/validator.go
Normal file
5
relay/auth/validator.go
Normal file
@ -0,0 +1,5 @@
|
||||
package auth
|
||||
|
||||
type Validator interface {
|
||||
Validate(any) error
|
||||
}
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
@ -98,6 +99,7 @@ type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
connectionURL string
|
||||
authStore *auth.Store
|
||||
hashedID []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
@ -115,12 +117,13 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverURL, peerID string) *Client {
|
||||
func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
parentCtx: ctx,
|
||||
connectionURL: serverURL,
|
||||
authStore: authStore,
|
||||
hashedID: hashedID,
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
@ -234,7 +237,12 @@ func (c *Client) connect() error {
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID)
|
||||
t, err := c.authStore.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID, t)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal hello message: %s", err)
|
||||
return err
|
||||
@ -262,11 +270,11 @@ func (c *Client) handShake() error {
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
domain, err := messages.UnmarshalHelloResponse(buf[:n])
|
||||
ia, err := messages.UnmarshalHelloResponse(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.instanceURL = domain
|
||||
c.instanceURL = ia
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -35,9 +37,10 @@ func NewRelayTrack() *RelayTrack {
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
serverURL string
|
||||
peerID string
|
||||
ctx context.Context
|
||||
serverURL string
|
||||
peerID string
|
||||
tokenStore *relayAuth.Store
|
||||
|
||||
relayClient *Client
|
||||
reconnectGuard *Guard
|
||||
@ -54,6 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
|
||||
ctx: ctx,
|
||||
serverURL: serverURL,
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.Store{},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
|
||||
}
|
||||
@ -65,7 +69,7 @@ func (m *Manager) Serve() error {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
|
||||
m.relayClient = NewClient(m.ctx, m.serverURL, m.peerID)
|
||||
m.relayClient = NewClient(m.ctx, m.serverURL, m.tokenStore, m.peerID)
|
||||
err := m.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to relay server: %s", err)
|
||||
@ -158,7 +162,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
|
||||
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
rt.Unlock()
|
||||
@ -260,3 +264,7 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
m.listenerLock.Unlock()
|
||||
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateToken(token relayAuth.Token) {
|
||||
m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ var (
|
||||
letsencryptDomains []string
|
||||
tlsCertFile string
|
||||
tlsKeyFile string
|
||||
authSecret string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "relay",
|
||||
@ -41,7 +42,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
|
||||
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&authSecret, "auth-secret", "s", "", "log level")
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
@ -56,6 +57,11 @@ func execute(cmd *cobra.Command, args []string) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if authSecret == "" {
|
||||
log.Errorf("auth secret is required")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
srvListenerCfg := server.ListenerConfig{
|
||||
Address: listenAddress,
|
||||
}
|
||||
@ -76,7 +82,7 @@ func execute(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
tlsSupport := srvListenerCfg.TLSConfig != nil
|
||||
srv := server.NewServer(exposedAddress, tlsSupport)
|
||||
srv := server.NewServer(exposedAddress, tlsSupport, authSecret)
|
||||
log.Infof("server will be available on: %s", srv.InstanceURL())
|
||||
err := srv.Listen(srvListenerCfg)
|
||||
if err != nil {
|
||||
|
@ -15,8 +15,10 @@ const (
|
||||
MsgTypeClose MsgType = 3
|
||||
MsgTypeHealthCheck MsgType = 4
|
||||
|
||||
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID
|
||||
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
||||
sizeOfMsgType = 1
|
||||
sizeOfMagicBye = 4
|
||||
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
|
||||
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
|
||||
|
||||
MaxHandshakeSize = 90
|
||||
)
|
||||
@ -47,7 +49,7 @@ func (m MsgType) String() string {
|
||||
}
|
||||
|
||||
type HelloResponse struct {
|
||||
DomainAddress string
|
||||
InstanceAddress string
|
||||
}
|
||||
|
||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
@ -83,28 +85,29 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
}
|
||||
|
||||
// MarshalHelloMsg initial hello message
|
||||
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
||||
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
msg := make([]byte, 5, headerSizeHello)
|
||||
msg := make([]byte, 5, headerSizeHello+len(additions))
|
||||
msg[0] = byte(MsgTypeHello)
|
||||
copy(msg[1:5], magicHeader)
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, additions...)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeHello {
|
||||
return nil, fmt.Errorf("invalid 'hello' messge")
|
||||
return nil, nil, fmt.Errorf("invalid 'hello' messge")
|
||||
}
|
||||
bytes.Equal(msg[1:5], magicHeader)
|
||||
return msg[5:], nil
|
||||
return msg[5:], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
|
||||
payload := HelloResponse{
|
||||
DomainAddress: DomainAddress,
|
||||
InstanceAddress: DomainAddress,
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
@ -135,7 +138,7 @@ func UnmarshalHelloResponse(msg []byte) (string, error) {
|
||||
log.Errorf("failed to gob decode hello response: %s", err)
|
||||
return "", err
|
||||
}
|
||||
return payload.DomainAddress, nil
|
||||
return payload.InstanceAddress, nil
|
||||
}
|
||||
|
||||
// Close message
|
||||
|
@ -6,12 +6,12 @@ import (
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalHelloMsg(peerID)
|
||||
bHello, err := MarshalHelloMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, err := UnmarshalHelloMsg(bHello)
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(bHello)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
@ -8,20 +8,24 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type Relay struct {
|
||||
validator auth.Validator
|
||||
|
||||
store *Store
|
||||
instaceURL string // domain:port
|
||||
instaceURL string
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
|
||||
func NewRelay(exposedAddress string, tlsSupport bool, validator auth.Validator) *Relay {
|
||||
r := &Relay{
|
||||
store: NewStore(),
|
||||
validator: validator,
|
||||
store: NewStore(),
|
||||
}
|
||||
|
||||
if tlsSupport {
|
||||
@ -29,6 +33,7 @@ func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
|
||||
} else {
|
||||
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@ -94,12 +99,17 @@ func (r *Relay) handShake(conn net.Conn) ([]byte, error) {
|
||||
return nil, tErr
|
||||
}
|
||||
|
||||
peerID, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.validator.Validate(authPayload); err != nil {
|
||||
log.Errorf("failed to authenticate peer with id: %s, %s", peerID, err)
|
||||
return nil, fmt.Errorf("failed to authenticate peer")
|
||||
}
|
||||
|
||||
msg, _ := messages.MarshalHelloResponse(r.instaceURL)
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
@ -25,9 +26,12 @@ type Server struct {
|
||||
wSListener listener.Listener
|
||||
}
|
||||
|
||||
func NewServer(exposedAddress string, tlsSupport bool) *Server {
|
||||
func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server {
|
||||
return &Server{
|
||||
relay: NewRelay(exposedAddress, tlsSupport),
|
||||
relay: NewRelay(
|
||||
exposedAddress,
|
||||
tlsSupport,
|
||||
auth.NewTimedHMACValidator(authSecret, 24*time.Hour)),
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user