Integrate the relay authentication

This commit is contained in:
Zoltan Papp 2024-07-05 16:12:30 +02:00
parent 8845e8fbc7
commit 836072098b
30 changed files with 3055 additions and 1594 deletions

View File

@ -86,7 +86,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -245,9 +246,10 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayAddress := relayAddress(loginResp) relayURL, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayAddress, myPrivateKey.PublicKey().String()) relayManager := relayClient.NewManager(engineCtx, relayURL, myPrivateKey.PublicKey().String())
if relayAddress != "" { if relayURL != "" {
relayManager.UpdateToken(token)
if err = relayManager.Serve(); err != nil { if err = relayManager.Serve(); err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
@ -307,15 +309,27 @@ func (c *ConnectClient) run(
return nil return nil
} }
func relayAddress(resp *mgmProto.LoginResponse) string { func parseRelayInfo(resp *mgmProto.LoginResponse) (string, hmac.Token) {
// todo remove this
if ra := peer.ForcedRelayAddress(); ra != "" { if ra := peer.ForcedRelayAddress(); ra != "" {
return ra return ra, hmac.Token{}
} }
if resp.GetWiretrusteeConfig().GetRelayAddress() != "" { msg := resp.GetWiretrusteeConfig().GetRelay()
return resp.GetWiretrusteeConfig().GetRelayAddress() if msg == nil {
return "", hmac.Token{}
} }
return ""
var url string
if msg.GetUrls() != nil && len(msg.GetUrls()) > 0 {
url = msg.GetUrls()[0]
}
token := hmac.Token{
Payload: msg.GetTokenPayload(),
Signature: msg.GetTokenSignature(),
}
return url, token
} }
func (c *ConnectClient) Engine() *Engine { func (c *ConnectClient) Engine() *Engine {

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@ -467,12 +469,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil { if update.GetWiretrusteeConfig() != nil {
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns()) wCfg := update.GetWiretrusteeConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil { if err != nil {
return err return err
} }
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns()) err = e.updateSTUNs(wCfg.GetStuns())
if err != nil { if err != nil {
return err return err
} }
@ -482,8 +485,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
e.StunTurn.Store(stunTurn) e.StunTurn.Store(stunTurn)
// todo update relay address in the relay manager relayMsg := wCfg.GetRelay()
if relayMsg != nil {
c := auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
}
e.relayManager.UpdateToken(c)
}
// todo update relay address in the relay manager
// todo update signal // todo update signal
} }

View File

@ -1071,7 +1071,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@ -122,7 +122,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@ -75,7 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -183,7 +183,7 @@ var (
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnRelayTokenManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.RelayAddress)
trustedPeers := config.ReverseProxy.TrustedPeers trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@ -260,7 +260,7 @@ var (
ephemeralManager.LoadInitialPeers() ephemeralManager.LoadInitialPeers()
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager) srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }

File diff suppressed because it is too large Load Diff

View File

@ -147,7 +147,7 @@ message WiretrusteeConfig {
// a Signal server config // a Signal server config
HostConfig signal = 3; HostConfig signal = 3;
string RelayAddress = 4; RelayConfig relay = 4;
} }
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@ -164,6 +164,13 @@ message HostConfig {
DTLS = 4; DTLS = 4;
} }
} }
message RelayConfig {
repeated string urls = 1;
string tokenPayload = 2;
string tokenSignature = 3;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password // ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers // Mostly used for TURN servers
message ProtectedHostConfig { message ProtectedHostConfig {

View File

@ -166,9 +166,9 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS turnCfg should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS turnCfg should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS turnCfg should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy() dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
@ -178,13 +178,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS turnCfg should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS turnCfg should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID) peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS turnCfg should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS turnCfg should have DNS service enabled for peers not in the disabled group")
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All") require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS turnCfg should have 1 nameserver groups since peer 2 is part of the group All")
} }
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {

View File

@ -29,17 +29,17 @@ type GRPCServer struct {
accountManager AccountManager accountManager AccountManager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnRelayTokenManager TURNRelayTokenManager
jwtValidator *jwtclaims.JWTValidator jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,7 +58,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
} }
} else { } else {
log.Debug("unable to use http config to create new jwt middleware") log.Debug("unable to use http turnCfg to create new jwt middleware")
} }
if appMetrics != nil { if appMetrics != nil {
@ -84,14 +84,14 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
accountManager: accountManager, accountManager: accountManager,
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnRelayTokenManager: turnRelayTokenManager,
jwtValidator: jwtValidator, jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor, jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
}, nil }, nil
} }
@ -150,7 +150,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(peer) s.ephemeralManager.OnPeerConnected(peer)
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID) s.turnRelayTokenManager.SetupRefresh(peer.ID)
} }
if s.appMetrics != nil { if s.appMetrics != nil {
@ -201,7 +201,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID) s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID) s.turnRelayTokenManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer) _ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer) s.ephemeralManager.OnPeerDisconnected(peer)
} }
@ -377,9 +377,14 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(peer) s.ephemeralManager.OnPeerDisconnected(peer)
} }
trt, err := s.turnRelayTokenManager.Generate()
if err != nil {
log.Errorf("failed generating TURN and Relay token: %v", err)
}
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
} }
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
@ -407,11 +412,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP: case TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected turnCfg protocol type %v", configProto))
} }
} }
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { func toWiretrusteeConfig(config *Config, turnCredentials *TURNRelayToken, relayToken *TURNRelayToken) *proto.WiretrusteeConfig {
if config == nil { if config == nil {
return nil return nil
} }
@ -427,8 +432,8 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
var username string var username string
var password string var password string
if turnCredentials != nil { if turnCredentials != nil {
username = turnCredentials.Username username = turnCredentials.Payload
password = turnCredentials.Password password = turnCredentials.Signature
} else { } else {
username = turn.Username username = turn.Username
password = turn.Password password = turn.Password
@ -443,6 +448,18 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
}) })
} }
var relayCfg *proto.RelayConfig
if config.RelayAddress != "" {
relayCfg = &proto.RelayConfig{
Urls: []string{config.RelayAddress},
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
return &proto.WiretrusteeConfig{ return &proto.WiretrusteeConfig{
Stuns: stuns, Stuns: stuns,
Turns: turns, Turns: turns,
@ -450,7 +467,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Uri: config.Signal.URI, Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto), Protocol: ToResponseProto(config.Signal.Proto),
}, },
RelayAddress: config.RelayAddress, Relay: relayCfg,
} }
} }
@ -478,8 +495,8 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers return remotePeers
} }
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials) wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName) pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@ -520,14 +537,16 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional // make secret time based TURN credentials optional
var turnCredentials *TURNCredentials var turnCredentials *TURNRelayToken
if s.config.TURNConfig.TimeBasedCredentials { trt, err := s.turnRelayTokenManager.Generate()
creds := s.turnCredentialsManager.GenerateCredentials() if err != nil {
turnCredentials = &creds log.Errorf("failed generating TURN and Relay token: %v", err)
} else {
turnCredentials = nil
} }
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain()) if s.config.TURNConfig.TimeBasedCredentials {
turnCredentials = trt
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@ -169,7 +169,7 @@ func Test_SyncProtocol(t *testing.T) {
} }
if wiretrusteeConfig.GetSignal() == nil { if wiretrusteeConfig.GetSignal() == nil {
t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal config") t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal turnCfg")
} }
expectedSignalConfig := &mgmtProto.HostConfig{ expectedSignalConfig := &mgmtProto.HostConfig{
@ -418,7 +418,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)

View File

@ -544,7 +544,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer) mgmtProto.RegisterManagementServiceServer(s, mgmtServer)

View File

@ -900,7 +900,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
continue continue
} }
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) update := toSyncResponse(nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
} }
} }

View File

@ -0,0 +1,126 @@
package server
import (
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
// TURNRelayTokenManager used to manage TURN credentials
type TURNRelayTokenManager interface {
Generate() (*TURNRelayToken, error)
SetupRefresh(peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayAddr string
hmacToken *auth.TimedHMAC
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNRelayToken auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayAddress string) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
updateManager: updateManager,
turnCfg: turnCfg,
relayAddr: relayAddress,
hmacToken: auth.NewTimedHMAC(turnCfg.Secret, turnCfg.CredentialsTTL.Duration),
cancelMap: make(map[string]chan struct{}),
}
}
// Generate generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) Generate() (*TURNRelayToken, error) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
return nil, fmt.Errorf("failed to generate token: %s", err)
}
return (*TURNRelayToken)(token), nil
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
defer ticker.Stop()
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTokens(peerID)
}
}
}()
}
func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
return
}
var turns []*proto.ProtectedHostConfig
for _, host := range m.turnCfg.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: token.Payload,
Password: token.Signature,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
Relay: &proto.RelayConfig{
Urls: []string{m.relayAddr},
TokenPayload: token.Payload,
TokenSignature: token.Signature,
},
},
}
log.Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
}

View File

@ -26,18 +26,18 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*Host{TurnTestHost},
}) }, "")
credentials := tested.GenerateCredentials() credentials, _ := tested.Generate()
if credentials.Username == "" { if credentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty") t.Errorf("expected generated TURN username not to be empty, got empty")
} }
if credentials.Password == "" { if credentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty") t.Errorf("expected generated TURN password not to be empty, got empty")
} }
validateMAC(t, credentials.Username, credentials.Password, []byte(secret)) validateMAC(t, credentials.Payload, credentials.Signature, []byte(secret))
} }
@ -52,7 +52,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*Host{TurnTestHost},
}) }, "")
tested.SetupRefresh(peer) tested.SetupRefresh(peer)
@ -100,7 +100,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*Host{TurnTestHost},
}) }, "")
tested.SetupRefresh(peer) tested.SetupRefresh(peer)
if _, ok := tested.cancelMap[peer]; !ok { if _, ok := tested.cancelMap[peer]; !ok {

View File

@ -1,125 +0,0 @@
package server
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
)
// TURNCredentialsManager used to manage TURN credentials
type TURNCredentialsManager interface {
GenerateCredentials() TURNCredentials
SetupRefresh(peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
config *TURNConfig
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNCredentials struct {
Username string
Password string
}
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
config: config,
updateManager: updateManager,
cancelMap: make(map[string]chan struct{}),
}
}
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
mac := hmac.New(sha1.New, []byte(m.config.Secret))
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
username := fmt.Sprint(timeAuth)
_, err := mac.Write([]byte(username))
if err != nil {
log.Errorln("Generating turn password failed with error: ", err)
}
bytePassword := mac.Sum(nil)
password := base64.StdEncoding.EncodeToString(bytePassword)
return TURNCredentials{
Username: username,
Password: password,
}
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
c := m.GenerateCredentials()
var turns []*proto.ProtectedHostConfig
for _, host := range m.config.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: c.Username,
Password: c.Password,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
},
}
log.Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
}
}
}()
}

9
relay/auth/allow_all.go Normal file
View File

@ -0,0 +1,9 @@
package auth
// AllowAllAuth is a Validator that allows all connections.
type AllowAllAuth struct {
}
func (a *AllowAllAuth) Validate(any) error {
return nil
}

24
relay/auth/hmac/store.go Normal file
View File

@ -0,0 +1,24 @@
package hmac
import (
"sync"
)
// Store is a simple in-memory store for token
// With this can update the token in thread safe way
type Store struct {
mu sync.Mutex
token Token
}
func (a *Store) UpdateToken(token Token) {
a.mu.Lock()
defer a.mu.Unlock()
a.token = token
}
func (a *Store) Token() ([]byte, error) {
a.mu.Lock()
defer a.mu.Unlock()
return marshalToken(a.token)
}

104
relay/auth/hmac/token.go Normal file
View File

@ -0,0 +1,104 @@
package hmac
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/gob"
"fmt"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type Token struct {
Payload string
Signature string
}
func marshalToken(token Token) ([]byte, error) {
buffer := bytes.NewBuffer([]byte{})
encoder := gob.NewEncoder(buffer)
err := encoder.Encode(token)
if err != nil {
log.Errorf("failed to marshal token: %s", err)
return nil, err
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
decoder := gob.NewDecoder(buffer)
err := decoder.Decode(&creds)
return creds, err
}
// TimedHMAC generates token with TTL and using pre-shared secret known to TURN server
type TimedHMAC struct {
mux sync.Mutex
secret string
timeToLive time.Duration
}
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
return &TimedHMAC{
secret: secret,
timeToLive: timeToLive,
}
}
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimedHMAC) GenerateToken() (*Token, error) {
timeAuth := time.Now().Add(m.timeToLive).Unix()
timeStamp := fmt.Sprint(timeAuth)
checksum, err := m.generate(timeStamp)
if err != nil {
return nil, err
}
return &Token{
Payload: timeStamp,
Signature: base64.StdEncoding.EncodeToString(checksum),
}, nil
}
func (m *TimedHMAC) Validate(token Token) error {
expectedMAC, err := m.generate(token.Payload)
if err != nil {
return err
}
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
return fmt.Errorf("signature mismatch")
}
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %s", err)
}
if time.Now().Unix() > timeAuthInt {
return fmt.Errorf("expired token")
}
return nil
}
func (m *TimedHMAC) generate(payload string) ([]byte, error) {
mac := hmac.New(sha1.New, []byte(m.secret))
_, err := mac.Write([]byte(payload))
if err != nil {
log.Errorf("failed to generate token: %s", err)
return nil, err
}
return mac.Sum(nil), nil
}

View File

@ -0,0 +1,103 @@
package hmac
import (
"encoding/base64"
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "secret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if creds.Payload == "" {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(creds.Payload, 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
_, err = base64.StdEncoding.DecodeString(creds.Signature)
if err != nil {
t.Fatalf("expected signature to be base64 encoded, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
// Test valid token
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err := manager.Validate(*creds); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
invalidCreds := &Token{
Payload: creds.Payload,
Signature: "invalidsignature",
}
if err = manager.Validate(*invalidCreds); err == nil {
t.Fatalf("expected invalid token due to signature mismatch")
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
v := NewTimedHMAC(secret, -1*time.Hour)
expiredCreds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err = v.Validate(*expiredCreds); err == nil {
t.Fatalf("expected invalid token due to expiration")
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
// Test invalid payload
invalidPayloadCreds := &Token{
Payload: "invalidtimestamp",
Signature: creds.Signature,
}
if err = v.Validate(*invalidPayloadCreds); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@ -0,0 +1,27 @@
package hmac
import (
log "github.com/sirupsen/logrus"
"time"
)
type TimedHMACValidator struct {
*TimedHMAC
}
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
ta := NewTimedHMAC(secret, duration)
return &TimedHMACValidator{
ta,
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
b := credentials.([]byte)
c, err := unmarshalToken(b)
if err != nil {
log.Errorf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(c)
}

5
relay/auth/validator.go Normal file
View File

@ -0,0 +1,5 @@
package auth
type Validator interface {
Validate(any) error
}

View File

@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
@ -98,6 +99,7 @@ type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context parentCtx context.Context
connectionURL string connectionURL string
authStore *auth.Store
hashedID []byte hashedID []byte
bufPool *sync.Pool bufPool *sync.Pool
@ -115,12 +117,13 @@ type Client struct {
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL, peerID string) *Client { func NewClient(ctx context.Context, serverURL string, authStore *auth.Store, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ return &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithField("client_id", hashedStringId),
parentCtx: ctx, parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authStore: authStore,
hashedID: hashedID, hashedID: hashedID,
bufPool: &sync.Pool{ bufPool: &sync.Pool{
New: func() any { New: func() any {
@ -234,7 +237,12 @@ func (c *Client) connect() error {
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
msg, err := messages.MarshalHelloMsg(c.hashedID) t, err := c.authStore.Token()
if err != nil {
return err
}
msg, err := messages.MarshalHelloMsg(c.hashedID, t)
if err != nil { if err != nil {
log.Errorf("failed to marshal hello message: %s", err) log.Errorf("failed to marshal hello message: %s", err)
return err return err
@ -262,11 +270,11 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
domain, err := messages.UnmarshalHelloResponse(buf[:n]) ia, err := messages.UnmarshalHelloResponse(buf[:n])
if err != nil { if err != nil {
return err return err
} }
c.instanceURL = domain c.instanceURL = ia
return nil return nil
} }

View File

@ -8,6 +8,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
) )
var ( var (
@ -35,9 +37,10 @@ func NewRelayTrack() *RelayTrack {
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it. // unused relay connection and close it.
type Manager struct { type Manager struct {
ctx context.Context ctx context.Context
serverURL string serverURL string
peerID string peerID string
tokenStore *relayAuth.Store
relayClient *Client relayClient *Client
reconnectGuard *Guard reconnectGuard *Guard
@ -54,6 +57,7 @@ func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
ctx: ctx, ctx: ctx,
serverURL: serverURL, serverURL: serverURL,
peerID: peerID, peerID: peerID,
tokenStore: &relayAuth.Store{},
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}), onDisconnectedListeners: make(map[string]map[*func()]struct{}),
} }
@ -65,7 +69,7 @@ func (m *Manager) Serve() error {
return fmt.Errorf("manager already serving") return fmt.Errorf("manager already serving")
} }
m.relayClient = NewClient(m.ctx, m.serverURL, m.peerID) m.relayClient = NewClient(m.ctx, m.serverURL, m.tokenStore, m.peerID)
err := m.relayClient.Connect() err := m.relayClient.Connect()
if err != nil { if err != nil {
log.Errorf("failed to connect to relay server: %s", err) log.Errorf("failed to connect to relay server: %s", err)
@ -158,7 +162,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock() m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.peerID) relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect() err := relayClient.Connect()
if err != nil { if err != nil {
rt.Unlock() rt.Unlock()
@ -260,3 +264,7 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Unlock() m.listenerLock.Unlock()
} }
func (m *Manager) UpdateToken(token relayAuth.Token) {
m.tokenStore.UpdateToken(token)
}

View File

@ -24,6 +24,7 @@ var (
letsencryptDomains []string letsencryptDomains []string
tlsCertFile string tlsCertFile string
tlsKeyFile string tlsKeyFile string
authSecret string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "relay", Use: "relay",
@ -41,7 +42,7 @@ func init() {
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "") rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "") rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
rootCmd.PersistentFlags().StringVarP(&authSecret, "auth-secret", "s", "", "log level")
} }
func waitForExitSignal() { func waitForExitSignal() {
@ -56,6 +57,11 @@ func execute(cmd *cobra.Command, args []string) {
os.Exit(1) os.Exit(1)
} }
if authSecret == "" {
log.Errorf("auth secret is required")
os.Exit(1)
}
srvListenerCfg := server.ListenerConfig{ srvListenerCfg := server.ListenerConfig{
Address: listenAddress, Address: listenAddress,
} }
@ -76,7 +82,7 @@ func execute(cmd *cobra.Command, args []string) {
} }
tlsSupport := srvListenerCfg.TLSConfig != nil tlsSupport := srvListenerCfg.TLSConfig != nil
srv := server.NewServer(exposedAddress, tlsSupport) srv := server.NewServer(exposedAddress, tlsSupport, authSecret)
log.Infof("server will be available on: %s", srv.InstanceURL()) log.Infof("server will be available on: %s", srv.InstanceURL())
err := srv.Listen(srvListenerCfg) err := srv.Listen(srvListenerCfg)
if err != nil { if err != nil {

View File

@ -15,8 +15,10 @@ const (
MsgTypeClose MsgType = 3 MsgTypeClose MsgType = 3
MsgTypeHealthCheck MsgType = 4 MsgTypeHealthCheck MsgType = 4
headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID sizeOfMsgType = 1
headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID sizeOfMagicBye = 4
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
MaxHandshakeSize = 90 MaxHandshakeSize = 90
) )
@ -47,7 +49,7 @@ func (m MsgType) String() string {
} }
type HelloResponse struct { type HelloResponse struct {
DomainAddress string InstanceAddress string
} }
func DetermineClientMsgType(msg []byte) (MsgType, error) { func DetermineClientMsgType(msg []byte) (MsgType, error) {
@ -83,28 +85,29 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
} }
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) { func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
if len(peerID) != IDSize { if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, 5, headerSizeHello) msg := make([]byte, 5, headerSizeHello+len(additions))
msg[0] = byte(MsgTypeHello) msg[0] = byte(MsgTypeHello)
copy(msg[1:5], magicHeader) copy(msg[1:5], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID...)
msg = append(msg, additions...)
return msg, nil return msg, nil
} }
func UnmarshalHelloMsg(msg []byte) ([]byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello { if len(msg) < headerSizeHello {
return nil, fmt.Errorf("invalid 'hello' messge") return nil, nil, fmt.Errorf("invalid 'hello' messge")
} }
bytes.Equal(msg[1:5], magicHeader) bytes.Equal(msg[1:5], magicHeader)
return msg[5:], nil return msg[5:], msg[headerSizeHello:], nil
} }
func MarshalHelloResponse(DomainAddress string) ([]byte, error) { func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
payload := HelloResponse{ payload := HelloResponse{
DomainAddress: DomainAddress, InstanceAddress: DomainAddress,
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -135,7 +138,7 @@ func UnmarshalHelloResponse(msg []byte) (string, error) {
log.Errorf("failed to gob decode hello response: %s", err) log.Errorf("failed to gob decode hello response: %s", err)
return "", err return "", err
} }
return payload.DomainAddress, nil return payload.InstanceAddress, nil
} }
// Close message // Close message

View File

@ -6,12 +6,12 @@ import (
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID) bHello, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
receivedPeerID, err := UnmarshalHelloMsg(bHello) receivedPeerID, _, err := UnmarshalHelloMsg(bHello)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }

View File

@ -8,20 +8,24 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
) )
type Relay struct { type Relay struct {
validator auth.Validator
store *Store store *Store
instaceURL string // domain:port instaceURL string
closed bool closed bool
closeMu sync.RWMutex closeMu sync.RWMutex
} }
func NewRelay(exposedAddress string, tlsSupport bool) *Relay { func NewRelay(exposedAddress string, tlsSupport bool, validator auth.Validator) *Relay {
r := &Relay{ r := &Relay{
store: NewStore(), validator: validator,
store: NewStore(),
} }
if tlsSupport { if tlsSupport {
@ -29,6 +33,7 @@ func NewRelay(exposedAddress string, tlsSupport bool) *Relay {
} else { } else {
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress) r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
} }
return r return r
} }
@ -94,12 +99,17 @@ func (r *Relay) handShake(conn net.Conn) ([]byte, error) {
return nil, tErr return nil, tErr
} }
peerID, err := messages.UnmarshalHelloMsg(buf[:n]) peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil { if err != nil {
log.Errorf("failed to handshake: %s", err) log.Errorf("failed to handshake: %s", err)
return nil, err return nil, err
} }
if err := r.validator.Validate(authPayload); err != nil {
log.Errorf("failed to authenticate peer with id: %s, %s", peerID, err)
return nil, fmt.Errorf("failed to authenticate peer")
}
msg, _ := messages.MarshalHelloResponse(r.instaceURL) msg, _ := messages.MarshalHelloResponse(r.instaceURL)
_, err = conn.Write(msg) _, err = conn.Write(msg)
if err != nil { if err != nil {

View File

@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp" "github.com/netbirdio/netbird/relay/server/listener/udp"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
@ -25,9 +26,12 @@ type Server struct {
wSListener listener.Listener wSListener listener.Listener
} }
func NewServer(exposedAddress string, tlsSupport bool) *Server { func NewServer(exposedAddress string, tlsSupport bool, authSecret string) *Server {
return &Server{ return &Server{
relay: NewRelay(exposedAddress, tlsSupport), relay: NewRelay(
exposedAddress,
tlsSupport,
auth.NewTimedHMACValidator(authSecret, 24*time.Hour)),
} }
} }