mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-29 17:17:20 +01:00
Merge branch 'main' into feature/add_pat_middleware
# Conflicts: # management/server/grpcserver.go # management/server/http/middleware/jwt.go
This commit is contained in:
commit
4d7029d80c
4
.github/workflows/golang-test-linux.yml
vendored
4
.github/workflows/golang-test-linux.yml
vendored
@ -72,7 +72,7 @@ jobs:
|
|||||||
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
- name: Generate Engine Test bin
|
||||||
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
run: go test -c -o engine-testing.bin ./client/internal
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
- name: Generate Peer Test bin
|
||||||
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||||
@ -89,4 +89,4 @@ jobs:
|
|||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
- name: Run Peer tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@ -23,6 +24,11 @@ type TunAdapter interface {
|
|||||||
iface.TunAdapter
|
iface.TunAdapter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||||
|
type IFaceDiscover interface {
|
||||||
|
stdnet.IFaceDiscover
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@ -31,6 +37,7 @@ func init() {
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
cfgFile string
|
||||||
tunAdapter iface.TunAdapter
|
tunAdapter iface.TunAdapter
|
||||||
|
iFaceDiscover IFaceDiscover
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
@ -38,7 +45,7 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||||
lvl, _ := log.ParseLevel("trace")
|
lvl, _ := log.ParseLevel("trace")
|
||||||
log.SetLevel(lvl)
|
log.SetLevel(lvl)
|
||||||
|
|
||||||
@ -46,6 +53,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
|
|||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
|
iFaceDiscover: iFaceDiscover,
|
||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
}
|
}
|
||||||
@ -70,14 +78,14 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.Login(urlOpener)
|
err = auth.login(urlOpener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter)
|
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
|
@ -3,16 +3,32 @@ package android
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SSOListener is async listener for mobile framework
|
||||||
|
type SSOListener interface {
|
||||||
|
OnSuccess(bool)
|
||||||
|
OnError(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrListener is async listener for mobile framework
|
||||||
|
type ErrListener interface {
|
||||||
|
OnSuccess()
|
||||||
|
OnError(error)
|
||||||
|
}
|
||||||
|
|
||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
@ -52,32 +68,66 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginAndSaveConfigIfSSOSupported test the connectivity with the management server.
|
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||||
// If the SSO is supported than save the configuration. Return with the SSO login is supported or not.
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
func (a *Auth) LoginAndSaveConfigIfSSOSupported() (bool, error) {
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
var needsLogin bool
|
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||||
|
go func() {
|
||||||
|
sso, err := a.saveConfigIfSSOSupported()
|
||||||
|
if err != nil {
|
||||||
|
listener.OnError(err)
|
||||||
|
} else {
|
||||||
|
listener.OnSuccess(sso)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
return
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
supportsSSO = false
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if !supportsSSO {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
if !needsLogin {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return needsLogin, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error {
|
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
go func() {
|
||||||
err := internal.Login(a.ctx, a.config, setupKey, "")
|
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if err != nil {
|
||||||
return nil
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
}
|
}
|
||||||
return err
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
//nolint
|
||||||
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
|
||||||
|
err := a.withBackOff(a.ctx, func() error {
|
||||||
|
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||||
|
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
|
// we got an answer from management, exit backoff earlier
|
||||||
|
return backoff.Permanent(backoffErr)
|
||||||
|
}
|
||||||
|
return backoffErr
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
@ -87,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login try register the client on the server
|
// Login try register the client on the server
|
||||||
func (a *Auth) Login(urlOpener URLOpener) error {
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||||
|
go func() {
|
||||||
|
err := a.login(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) login(urlOpener URLOpener) error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
|
@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil)
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@ -22,7 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@ -146,7 +147,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter)
|
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@ -163,6 +164,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
|
|
||||||
@ -193,12 +195,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) {
|
||||||
|
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: peerConfig.Address,
|
WgAddr: peerConfig.Address,
|
||||||
TunAdapter: tunAdapter,
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@ -49,6 +50,8 @@ type EngineConfig struct {
|
|||||||
// TunAdapter is option. It is necessary for mobile version.
|
// TunAdapter is option. It is necessary for mobile version.
|
||||||
TunAdapter iface.TunAdapter
|
TunAdapter iface.TunAdapter
|
||||||
|
|
||||||
|
IFaceDiscover stdnet.IFaceDiscover
|
||||||
|
|
||||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||||
WgAddr string
|
WgAddr string
|
||||||
|
|
||||||
@ -186,12 +189,22 @@ func (e *Engine) Start() error {
|
|||||||
networkName = "udp4"
|
networkName = "udp4"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
transportNet, err := e.newStdNet()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
udpMuxParams := ice.UDPMuxParams{
|
||||||
|
UDPConn: e.udpMuxConn,
|
||||||
|
Net: transportNet,
|
||||||
|
}
|
||||||
|
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||||
|
|
||||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -199,9 +212,7 @@ func (e *Engine) Start() error {
|
|||||||
e.close()
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||||
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
|
|
||||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
|
|
||||||
|
|
||||||
err = e.wgInterface.Create()
|
err = e.wgInterface.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -813,7 +824,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(config, e.statusRecorder)
|
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
11
client/internal/engine_stdnet.go
Normal file
11
client/internal/engine_stdnet.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet()
|
||||||
|
}
|
7
client/internal/engine_stdnet_android.go
Normal file
7
client/internal/engine_stdnet_android.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet(e.config.IFaceDiscover)
|
||||||
|
}
|
@ -9,15 +9,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
"github.com/pion/transport/v2/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
@ -93,6 +93,9 @@ type Conn struct {
|
|||||||
proxy proxy.Proxy
|
proxy proxy.Proxy
|
||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
|
adapter iface.TunAdapter
|
||||||
|
iFaceDiscover stdnet.IFaceDiscover
|
||||||
}
|
}
|
||||||
|
|
||||||
// meta holds meta information about a connection
|
// meta holds meta information about a connection
|
||||||
@ -118,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
|||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
|
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
config: config,
|
config: config,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
@ -128,6 +131,8 @@ func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
|
|||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
remoteModeCh: make(chan ModeMessage, 1),
|
remoteModeCh: make(chan ModeMessage, 1),
|
||||||
|
adapter: adapter,
|
||||||
|
iFaceDiscover: iFaceDiscover,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +167,9 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
failedTimeout := 6 * time.Second
|
failedTimeout := 6 * time.Second
|
||||||
transportNet, err := stdnet.NewNet()
|
|
||||||
|
var err error
|
||||||
|
transportNet, err := conn.newStdNet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
conn, err := NewConn(connConf, nil)
|
conn, err := NewConn(connConf, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -411,7 +411,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
g := errgroup.Group{}
|
g := errgroup.Group{}
|
||||||
conn, err := NewConn(connConf, nil)
|
conn, err := NewConn(connConf, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -5,5 +5,7 @@ type Listener interface {
|
|||||||
OnConnected()
|
OnConnected()
|
||||||
OnDisconnected()
|
OnDisconnected()
|
||||||
OnConnecting()
|
OnConnecting()
|
||||||
|
OnDisconnecting()
|
||||||
|
OnAddressChanged(string, string)
|
||||||
OnPeersListChanged(int)
|
OnPeersListChanged(int)
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ const (
|
|||||||
stateDisconnected = iota
|
stateDisconnected = iota
|
||||||
stateConnected
|
stateConnected
|
||||||
stateConnecting
|
stateConnecting
|
||||||
|
stateDisconnecting
|
||||||
)
|
)
|
||||||
|
|
||||||
type notifier struct {
|
type notifier struct {
|
||||||
@ -57,8 +58,12 @@ func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
n.currentServerState = newState
|
n.currentServerState = newState
|
||||||
n.lastNotification = n.calculateState(newState, n.currentClientState)
|
|
||||||
|
|
||||||
|
if n.lastNotification == stateDisconnecting {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.lastNotification = n.calculateState(newState, n.currentClientState)
|
||||||
go n.notifyAll(n.lastNotification)
|
go n.notifyAll(n.lastNotification)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +83,14 @@ func (n *notifier) clientStop() {
|
|||||||
go n.notifyAll(n.lastNotification)
|
go n.notifyAll(n.lastNotification)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientTearDown() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = false
|
||||||
|
n.lastNotification = stateDisconnecting
|
||||||
|
go n.notifyAll(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
func (n *notifier) isServerStateChanged(newState bool) bool {
|
func (n *notifier) isServerStateChanged(newState bool) bool {
|
||||||
return n.currentServerState != newState
|
return n.currentServerState != newState
|
||||||
}
|
}
|
||||||
@ -99,6 +112,8 @@ func (n *notifier) notifyListener(l Listener, state int) {
|
|||||||
l.OnConnected()
|
l.OnConnected()
|
||||||
case stateConnecting:
|
case stateConnecting:
|
||||||
l.OnConnecting()
|
l.OnConnecting()
|
||||||
|
case stateDisconnecting:
|
||||||
|
l.OnDisconnecting()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,3 +137,12 @@ func (n *notifier) peerListChanged(numOfPeers int) {
|
|||||||
l.OnPeersListChanged(numOfPeers)
|
l.OnPeersListChanged(numOfPeers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *notifier) localAddressChanged(fqdn, address string) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
|
||||||
|
for l := range n.listeners {
|
||||||
|
l.OnAddressChanged(fqdn, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -190,6 +190,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
d.localPeer = localPeerState
|
d.localPeer = localPeerState
|
||||||
|
d.notifyAddressChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanLocalPeerState cleans local peer status
|
// CleanLocalPeerState cleans local peer status
|
||||||
@ -198,6 +199,7 @@ func (d *Status) CleanLocalPeerState() {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
d.localPeer = LocalPeerState{}
|
d.localPeer = LocalPeerState{}
|
||||||
|
d.notifyAddressChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||||
@ -215,7 +217,7 @@ func (d *Status) MarkManagementConnected() {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = true
|
d.managementState = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSignalAddress update the address of the signal server
|
// UpdateSignalAddress update the address of the signal server
|
||||||
@ -238,7 +240,7 @@ func (d *Status) MarkSignalDisconnected() {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = false
|
d.signalState = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
// MarkSignalConnected sets SignalState to connected
|
||||||
@ -286,6 +288,11 @@ func (d *Status) ClientStop() {
|
|||||||
d.notifier.clientStop()
|
d.notifier.clientStop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientTeardown will notify all listeners about the service is under teardown
|
||||||
|
func (d *Status) ClientTeardown() {
|
||||||
|
d.notifier.clientTearDown()
|
||||||
|
}
|
||||||
|
|
||||||
// AddConnectionListener add a listener to the notifier
|
// AddConnectionListener add a listener to the notifier
|
||||||
func (d *Status) AddConnectionListener(listener Listener) {
|
func (d *Status) AddConnectionListener(listener Listener) {
|
||||||
d.notifier.addListener(listener)
|
d.notifier.addListener(listener)
|
||||||
@ -303,3 +310,7 @@ func (d *Status) onConnectionChanged() {
|
|||||||
func (d *Status) notifyPeerListChanged() {
|
func (d *Status) notifyPeerListChanged() {
|
||||||
d.notifier.peerListChanged(len(d.peers))
|
d.notifier.peerListChanged(len(d.peers))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyAddressChanged() {
|
||||||
|
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||||
|
}
|
||||||
|
11
client/internal/peer/stdnet.go
Normal file
11
client/internal/peer/stdnet.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet()
|
||||||
|
}
|
7
client/internal/peer/stdnet_android.go
Normal file
7
client/internal/peer/stdnet_android.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet(conn.iFaceDiscover)
|
||||||
|
}
|
8
client/internal/stdnet/iface_discover.go
Normal file
8
client/internal/stdnet/iface_discover.go
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
// IFaceDiscover provide an option for external services (mobile)
|
||||||
|
// to collect network interface information
|
||||||
|
type IFaceDiscover interface {
|
||||||
|
// IFaces return with the description of the interfaces
|
||||||
|
IFaces() (string, error)
|
||||||
|
}
|
137
client/internal/stdnet/stdnet.go
Normal file
137
client/internal/stdnet/stdnet.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
// Package stdnet is an extension of the pion's stdnet.
|
||||||
|
// With it the list of the interface can come from external source.
|
||||||
|
// More info: https://github.com/golang/go/issues/40569
|
||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Net is an implementation of the net.Net interface
|
||||||
|
// based on functions of the standard net package.
|
||||||
|
type Net struct {
|
||||||
|
stdnet.Net
|
||||||
|
interfaces []*transport.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNet creates a new StdNet instance.
|
||||||
|
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) {
|
||||||
|
n := &Net{}
|
||||||
|
|
||||||
|
return n, n.UpdateInterfaces(iFaceDiscover)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
|
// and associated addresses.
|
||||||
|
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error {
|
||||||
|
ifacesString, err := iFaceDiscover.IFaces()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.interfaces = parseInterfacesString(ifacesString)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interfaces returns a slice of interfaces which are available on the
|
||||||
|
// system
|
||||||
|
func (n *Net) Interfaces() ([]*transport.Interface, error) {
|
||||||
|
return n.interfaces, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterfaceByIndex returns the interface specified by index.
|
||||||
|
//
|
||||||
|
// On Solaris, it returns one of the logical network interfaces
|
||||||
|
// sharing the logical data link; for more precision use
|
||||||
|
// InterfaceByName.
|
||||||
|
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
|
||||||
|
for _, ifc := range n.interfaces {
|
||||||
|
if ifc.Index == index {
|
||||||
|
return ifc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterfaceByName returns the interface specified by name.
|
||||||
|
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
|
||||||
|
for _, ifc := range n.interfaces {
|
||||||
|
if ifc.Name == name {
|
||||||
|
return ifc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInterfacesString(interfaces string) []*transport.Interface {
|
||||||
|
ifs := []*transport.Interface{}
|
||||||
|
|
||||||
|
for _, iface := range strings.Split(interfaces, "\n") {
|
||||||
|
if strings.TrimSpace(iface) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Split(iface, "|")
|
||||||
|
if len(fields) != 2 {
|
||||||
|
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var index, mtu int
|
||||||
|
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||||
|
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||||
|
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newIf := net.Interface{
|
||||||
|
Name: name,
|
||||||
|
Index: index,
|
||||||
|
MTU: mtu,
|
||||||
|
}
|
||||||
|
if up {
|
||||||
|
newIf.Flags |= net.FlagUp
|
||||||
|
}
|
||||||
|
if broadcast {
|
||||||
|
newIf.Flags |= net.FlagBroadcast
|
||||||
|
}
|
||||||
|
if loopback {
|
||||||
|
newIf.Flags |= net.FlagLoopback
|
||||||
|
}
|
||||||
|
if pointToPoint {
|
||||||
|
newIf.Flags |= net.FlagPointToPoint
|
||||||
|
}
|
||||||
|
if multicast {
|
||||||
|
newIf.Flags |= net.FlagMulticast
|
||||||
|
}
|
||||||
|
|
||||||
|
ifc := transport.NewInterface(newIf)
|
||||||
|
|
||||||
|
addrs := strings.Trim(fields[1], " \n")
|
||||||
|
foundAddress := false
|
||||||
|
for _, addr := range strings.Split(addrs, " ") {
|
||||||
|
ip, ipNet, err := net.ParseCIDR(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("%s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipNet.IP = ip
|
||||||
|
ifc.AddAddress(ipNet)
|
||||||
|
foundAddress = true
|
||||||
|
}
|
||||||
|
if foundAddress {
|
||||||
|
ifs = append(ifs, ifc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ifs
|
||||||
|
}
|
66
client/internal/stdnet/stdnet_test.go
Normal file
66
client/internal/stdnet/stdnet_test.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseInterfacesString(t *testing.T) {
|
||||||
|
testData := []struct {
|
||||||
|
name string
|
||||||
|
index int
|
||||||
|
mtu int
|
||||||
|
up bool
|
||||||
|
broadcast bool
|
||||||
|
loopBack bool
|
||||||
|
pointToPoint bool
|
||||||
|
multicast bool
|
||||||
|
addr string
|
||||||
|
}{
|
||||||
|
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
|
||||||
|
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
|
||||||
|
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var exampleString string
|
||||||
|
for _, d := range testData {
|
||||||
|
exampleString = fmt.Sprintf("%s\n%s %d %d %t %t %t %t %t | %s", exampleString,
|
||||||
|
d.name,
|
||||||
|
d.index,
|
||||||
|
d.mtu,
|
||||||
|
d.up,
|
||||||
|
d.broadcast,
|
||||||
|
d.loopBack,
|
||||||
|
d.pointToPoint,
|
||||||
|
d.multicast,
|
||||||
|
d.addr)
|
||||||
|
}
|
||||||
|
nets := parseInterfacesString(exampleString)
|
||||||
|
if len(nets) == 0 {
|
||||||
|
t.Fatalf("failed to parse interfaces")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, net := range nets {
|
||||||
|
if net.MTU != testData[i].mtu {
|
||||||
|
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.Interface.Name != testData[i].name {
|
||||||
|
t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := net.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addr) == 0 {
|
||||||
|
t.Errorf("invalid address parsing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr[0].String() != testData[i].addr {
|
||||||
|
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -78,7 +78,7 @@ func (s *Server) Start() error {
|
|||||||
// on failure we return error to retry
|
// on failure we return error to retry
|
||||||
config, err := internal.UpdateConfig(s.latestConfigInput)
|
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||||
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
|
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
|
||||||
config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
|
s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("unable to create configuration file: %v", err)
|
log.Warnf("unable to create configuration file: %v", err)
|
||||||
return err
|
return err
|
||||||
@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil {
|
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil {
|
||||||
log.Errorf("init connections: %v", err)
|
log.Errorf("init connections: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -394,7 +394,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil {
|
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil {
|
||||||
log.Errorf("run client connection: %v", err)
|
log.Errorf("run client connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
func extractDeviceName(ctx context.Context) string {
|
func extractDeviceName(ctx context.Context) string {
|
||||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return "android"
|
||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,13 @@ package encryption
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/box"
|
"golang.org/x/crypto/nacl/box"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const nonceSize = 24
|
||||||
|
|
||||||
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||||
// Wireguard keys are used for encryption
|
// Wireguard keys are used for encryption
|
||||||
@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
copy(nonce[:], encryptedMsg[:24])
|
if len(encryptedMsg) < nonceSize {
|
||||||
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
return nil, fmt.Errorf("invalid encrypted message lenght")
|
||||||
|
}
|
||||||
|
copy(nonce[:], encryptedMsg[:nonceSize])
|
||||||
|
opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
||||||
}
|
}
|
||||||
@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generates nonce of size 24
|
// Generates nonce of size 24
|
||||||
func genNonce() (*[24]byte, error) {
|
func genNonce() (*[nonceSize]byte, error) {
|
||||||
var nonce [24]byte
|
var nonce [nonceSize]byte
|
||||||
if _, err := rand.Read(nonce[:]); err != nil {
|
if _, err := rand.Read(nonce[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
|
|
||||||
if p.PresharedKey != nil {
|
if p.PresharedKey != nil {
|
||||||
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||||
sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey))
|
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Remove {
|
if p.Remove {
|
||||||
|
@ -68,10 +68,10 @@ type AccountManager interface {
|
|||||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||||
GetPeerNetwork(peerID string) (*Network, error)
|
GetPeerNetwork(peerID string) (*Network, error)
|
||||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||||
CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||||
DeletePAT(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(accountID string, executingUserID string, targetUserId string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(accountID string, executingUserID string, targetUserId string) ([]*PersonalAccessToken, error)
|
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(accountId, groupID string) (*Group, error)
|
GetGroup(accountId, groupID string) (*Group, error)
|
||||||
@ -362,11 +362,11 @@ func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
|
|||||||
return *nextExpiry, true
|
return *nextExpiry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true
|
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user
|
||||||
func (a *Account) GetPeersWithExpiration() []*Peer {
|
func (a *Account) GetPeersWithExpiration() []*Peer {
|
||||||
peers := make([]*Peer, 0)
|
peers := make([]*Peer, 0)
|
||||||
for _, peer := range a.Peers {
|
for _, peer := range a.Peers {
|
||||||
if peer.LoginExpirationEnabled {
|
if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||||
peers = append(peers, peer)
|
peers = append(peers, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1608,9 +1608,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
|||||||
peers: map[string]*Peer{
|
peers: map[string]*Peer{
|
||||||
"peer-1": {
|
"peer-1": {
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedPeers: map[string]struct{}{},
|
expectedPeers: map[string]struct{}{},
|
||||||
@ -1621,9 +1623,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
|||||||
"peer-1": {
|
"peer-1": {
|
||||||
ID: "peer-1",
|
ID: "peer-1",
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedPeers: map[string]struct{}{
|
expectedPeers: map[string]struct{}{
|
||||||
@ -1683,12 +1687,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
Connected: false,
|
Connected: false,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &PeerStatus{
|
Status: &PeerStatus{
|
||||||
Connected: true,
|
Connected: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Second,
|
expiration: time.Second,
|
||||||
@ -1704,12 +1710,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
Connected: true,
|
Connected: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &PeerStatus{
|
Status: &PeerStatus{
|
||||||
Connected: true,
|
Connected: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Second,
|
expiration: time.Second,
|
||||||
@ -1726,6 +1734,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: true,
|
LoginExpired: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &PeerStatus{
|
Status: &PeerStatus{
|
||||||
@ -1733,6 +1742,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: true,
|
LoginExpired: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Second,
|
expiration: time.Second,
|
||||||
@ -1750,6 +1760,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
LastLogin: time.Now(),
|
LastLogin: time.Now(),
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &PeerStatus{
|
Status: &PeerStatus{
|
||||||
@ -1757,6 +1768,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: true,
|
LoginExpired: true,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Minute,
|
expiration: time.Minute,
|
||||||
@ -1764,6 +1776,31 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
expectedNextRun: true,
|
expectedNextRun: true,
|
||||||
expectedNextExpiration: expectedNextExpiration,
|
expectedNextExpiration: expectedNextExpiration,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Peers added with setup keys, no expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
@ -46,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var jwtValidator *jwtclaims.JWTValidator
|
var jwtMiddleware *middleware.JWTMiddleware
|
||||||
|
|
||||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
jwtMiddleware, err = middleware.NewJwtMiddleware(
|
||||||
config.HttpConfig.AuthIssuer,
|
config.HttpConfig.AuthIssuer,
|
||||||
config.HttpConfig.AuthAudience,
|
config.HttpConfig.AuthAudience,
|
||||||
config.HttpConfig.AuthKeysLocation)
|
config.HttpConfig.AuthKeysLocation)
|
||||||
@ -87,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
config: config,
|
config: config,
|
||||||
turnCredentialsManager: turnCredentialsManager,
|
turnCredentialsManager: turnCredentialsManager,
|
||||||
jwtValidator: jwtValidator,
|
jwtMiddleware: jwtMiddleware,
|
||||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
}, nil
|
}, nil
|
||||||
@ -188,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||||
if s.jwtValidator == nil {
|
if s.jwtMiddleware == nil {
|
||||||
return "", status.Error(codes.Internal, "no jwt validator set")
|
return "", status.Error(codes.Internal, "no jwt middleware set")
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
|
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||||
}
|
}
|
||||||
@ -223,6 +223,7 @@ func mapError(err error) error {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Errorf("got an unhandled error: %s", err)
|
||||||
return status.Errorf(codes.Internal, "failed handling request")
|
return status.Errorf(codes.Internal, "failed handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
0
management/server/http/middleware/jwt.go
Normal file
0
management/server/http/middleware/jwt.go
Normal file
@ -2,10 +2,11 @@ package idp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager idp manager interface
|
// Manager idp manager interface
|
||||||
@ -20,8 +21,9 @@ type Manager interface {
|
|||||||
|
|
||||||
// Config an idp configuration struct to be loaded from management server's config file
|
// Config an idp configuration struct to be loaded from management server's config file
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ManagerType string
|
ManagerType string
|
||||||
Auth0ClientCredentials Auth0ClientConfig
|
Auth0ClientCredentials Auth0ClientConfig
|
||||||
|
KeycloakClientCredentials KeycloakClientConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||||
@ -71,6 +73,8 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
case "auth0":
|
case "auth0":
|
||||||
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
||||||
|
case "keycloak":
|
||||||
|
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
}
|
}
|
||||||
|
581
management/server/idp/keycloak.go
Normal file
581
management/server/idp/keycloak.go
Normal file
@ -0,0 +1,581 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wtAccountID = "wt_account_id"
|
||||||
|
wtPendingInvite = "wt_pending_invite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeycloakManager keycloak manager client instance.
|
||||||
|
type KeycloakManager struct {
|
||||||
|
adminEndpoint string
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
credentials ManagerCredentials
|
||||||
|
helper ManagerHelper
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeycloakClientConfig keycloak manager client configurations.
|
||||||
|
type KeycloakClientConfig struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
AdminEndpoint string
|
||||||
|
TokenEndpoint string
|
||||||
|
GrantType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeycloakCredentials keycloak authentication information.
|
||||||
|
type KeycloakCredentials struct {
|
||||||
|
clientConfig KeycloakClientConfig
|
||||||
|
helper ManagerHelper
|
||||||
|
httpClient ManagerHTTPClient
|
||||||
|
jwtToken JWTToken
|
||||||
|
mux sync.Mutex
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// keycloakUserCredential describe the authentication method for,
|
||||||
|
// newly created user profile.
|
||||||
|
type keycloakUserCredential struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Temporary bool `json:"temporary"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// keycloakUserAttributes holds additional user data fields.
|
||||||
|
type keycloakUserAttributes map[string][]string
|
||||||
|
|
||||||
|
// createUserRequest is a user create request.
|
||||||
|
type keycloakCreateUserRequest struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
EmailVerified bool `json:"emailVerified"`
|
||||||
|
Credentials []keycloakUserCredential `json:"credentials"`
|
||||||
|
Attributes keycloakUserAttributes `json:"attributes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// keycloakProfile represents an keycloak user profile response.
|
||||||
|
type keycloakProfile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Attributes keycloakUserAttributes `json:"attributes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||||
|
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
helper := JsonParser{}
|
||||||
|
|
||||||
|
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
|
||||||
|
return nil, fmt.Errorf("keycloak idp configuration is not complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.GrantType != "client_credentials" {
|
||||||
|
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := &KeycloakCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: httpClient,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeycloakManager{
|
||||||
|
adminEndpoint: config.AdminEndpoint,
|
||||||
|
httpClient: httpClient,
|
||||||
|
credentials: credentials,
|
||||||
|
helper: helper,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
||||||
|
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||||
|
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestJWTToken performs request to get jwt token.
|
||||||
|
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", kc.clientConfig.ClientID)
|
||||||
|
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||||
|
data.Set("grant_type", kc.clientConfig.GrantType)
|
||||||
|
|
||||||
|
payload := strings.NewReader(data.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
log.Debug("requesting new jwt token for keycloak idp manager")
|
||||||
|
|
||||||
|
resp, err := kc.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if kc.appMetrics != nil {
|
||||||
|
kc.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||||
|
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
body, err := io.ReadAll(rawBody)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = kc.helper.Unmarshal(body, &jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||||
|
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exp maps into exp from jwt token
|
||||||
|
var IssuedAt struct{ Exp int64 }
|
||||||
|
err = kc.helper.Unmarshal(data, &IssuedAt)
|
||||||
|
if err != nil {
|
||||||
|
return jwtToken, err
|
||||||
|
}
|
||||||
|
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||||
|
|
||||||
|
return jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate retrieves access token to use the keycloak Management API.
|
||||||
|
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
kc.mux.Lock()
|
||||||
|
defer kc.mux.Unlock()
|
||||||
|
|
||||||
|
if kc.appMetrics != nil {
|
||||||
|
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reuse the token without requesting a new one if it is not expired,
|
||||||
|
// and if expiry time is sufficient time available to make a request.
|
||||||
|
if kc.jwtStillValid() {
|
||||||
|
return kc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := kc.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return kc.jwtToken, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return kc.jwtToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
kc.jwtToken = jwtToken
|
||||||
|
|
||||||
|
return kc.jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||||
|
func (km *KeycloakManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||||
|
jwtToken, err := km.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
appMetadata := AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
|
||||||
|
payload := strings.NewReader(payloadString)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountCreateUser()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := km.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
locationHeader := resp.Header.Get("location")
|
||||||
|
userID, err := extractUserIDFromLocationHeader(locationHeader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return km.GetUserDataByID(userID, appMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail searches users with a given email.
|
||||||
|
// If no users have been found, this function returns an empty list.
|
||||||
|
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("email", email)
|
||||||
|
q.Add("exact", "true")
|
||||||
|
|
||||||
|
body, err := km.get("users", q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := make([]keycloakProfile, 0)
|
||||||
|
err = km.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, profile := range profiles {
|
||||||
|
users = append(users, profile.userData())
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID requests user data from keycloak via ID.
|
||||||
|
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
|
body, err := km.get("users/"+userID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile keycloakProfile
|
||||||
|
err = km.helper.Unmarshal(body, &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return profile.userData(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns all the users for a given profile.
|
||||||
|
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("q", wtAccountID+":"+accountID)
|
||||||
|
|
||||||
|
body, err := km.get("users", q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := make([]keycloakProfile, 0)
|
||||||
|
err = km.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*UserData, 0)
|
||||||
|
for _, profile := range profiles {
|
||||||
|
users = append(users, profile.userData())
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
|
// It returns a list of users indexed by accountID.
|
||||||
|
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
|
totalUsers, err := km.totalUsersCount()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Add("max", fmt.Sprint(*totalUsers))
|
||||||
|
|
||||||
|
body, err := km.get("users", q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||||
|
}
|
||||||
|
|
||||||
|
profiles := make([]keycloakProfile, 0)
|
||||||
|
err = km.helper.Unmarshal(body, &profiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedUsers := make(map[string][]*UserData)
|
||||||
|
for _, profile := range profiles {
|
||||||
|
userData := profile.userData()
|
||||||
|
|
||||||
|
accountID := userData.AppMetadata.WTAccountID
|
||||||
|
if accountID != "" {
|
||||||
|
if _, ok := indexedUsers[accountID]; !ok {
|
||||||
|
indexedUsers[accountID] = make([]*UserData, 0)
|
||||||
|
}
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexedUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||||
|
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||||
|
jwtToken, err := km.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := keycloakUserAttributes{}
|
||||||
|
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||||
|
if appMetadata.WTPendingInvite != nil {
|
||||||
|
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||||
|
} else {
|
||||||
|
attrs.Set(wtPendingInvite, "false")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
|
||||||
|
data, err := km.helper.Marshal(map[string]any{
|
||||||
|
"attributes": attrs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := strings.NewReader(string(data))
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
log.Debugf("updating IdP metadata for user %s", userID)
|
||||||
|
|
||||||
|
resp, err := km.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||||
|
attrs := keycloakUserAttributes{}
|
||||||
|
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||||
|
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||||
|
|
||||||
|
req := &keycloakCreateUserRequest{
|
||||||
|
Email: email,
|
||||||
|
Username: name,
|
||||||
|
Enabled: true,
|
||||||
|
EmailVerified: true,
|
||||||
|
Credentials: []keycloakUserCredential{
|
||||||
|
{
|
||||||
|
Type: "password",
|
||||||
|
Value: GeneratePassword(8, 1, 1, 1),
|
||||||
|
Temporary: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Attributes: attrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(str), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get perform Get requests.
|
||||||
|
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
|
jwtToken, err := km.credentials.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
||||||
|
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||||
|
req.Header.Add("content-type", "application/json")
|
||||||
|
|
||||||
|
resp, err := km.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if km.appMetrics != nil {
|
||||||
|
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// totalUsersCount returns the total count of all user created.
|
||||||
|
// Used when fetching all registered accounts with pagination.
|
||||||
|
func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||||
|
body, err := km.get("users/count", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := strconv.Atoi(string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractUserIDFromLocationHeader extracts the user ID from the location,
|
||||||
|
// header once the user is created successfully
|
||||||
|
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
|
||||||
|
userURL, err := url.Parse(locationHeader)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return path.Base(userURL.Path), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// userData construct user data from keycloak profile.
|
||||||
|
func (kp keycloakProfile) userData() *UserData {
|
||||||
|
accountID := kp.Attributes.Get(wtAccountID)
|
||||||
|
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
|
||||||
|
if err != nil {
|
||||||
|
pendingInvite = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserData{
|
||||||
|
Email: kp.Email,
|
||||||
|
Name: kp.Username,
|
||||||
|
ID: kp.ID,
|
||||||
|
AppMetadata: AppMetadata{
|
||||||
|
WTAccountID: accountID,
|
||||||
|
WTPendingInvite: &pendingInvite,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the key to value. It replaces any existing
|
||||||
|
// values.
|
||||||
|
func (ka keycloakUserAttributes) Set(key, value string) {
|
||||||
|
ka[key] = []string{value}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the first value associated with the given key.
|
||||||
|
// If there are no values associated with the key, Get returns
|
||||||
|
// the empty string.
|
||||||
|
func (ka keycloakUserAttributes) Get(key string) string {
|
||||||
|
if ka == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
values := ka[key]
|
||||||
|
if len(values) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return values[0]
|
||||||
|
}
|
401
management/server/idp/keycloak_test.go
Normal file
401
management/server/idp/keycloak_test.go
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewKeycloakManager(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
inputConfig KeycloakClientConfig
|
||||||
|
assertErrFunc require.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTestConfig := KeycloakClientConfig{
|
||||||
|
ClientID: "client_id",
|
||||||
|
ClientSecret: "client_secret",
|
||||||
|
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
||||||
|
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
||||||
|
GrantType: "client_credentials",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase1 := test{
|
||||||
|
name: "Good Configuration",
|
||||||
|
inputConfig: defaultTestConfig,
|
||||||
|
assertErrFunc: require.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase2Config := defaultTestConfig
|
||||||
|
testCase2Config.ClientID = ""
|
||||||
|
|
||||||
|
testCase2 := test{
|
||||||
|
name: "Missing ClientID Configuration",
|
||||||
|
inputConfig: testCase2Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase5Config := defaultTestConfig
|
||||||
|
testCase5Config.GrantType = "authorization_code"
|
||||||
|
|
||||||
|
testCase5 := test{
|
||||||
|
name: "Wrong GrantType",
|
||||||
|
inputConfig: testCase5Config,
|
||||||
|
assertErrFunc: require.Error,
|
||||||
|
assertErrFuncMessage: "should return error when wrong grant type",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase5} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockKeycloakCredentials struct {
|
||||||
|
jwtToken JWTToken
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||||
|
return mc.jwtToken, mc.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||||
|
|
||||||
|
type requestJWTTokenTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||||
|
name: "Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||||
|
name: "Request Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputRespBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputRespBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := KeycloakClientConfig{}
|
||||||
|
|
||||||
|
creds := KeycloakCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := creds.requestJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err, "unable to read the response body")
|
||||||
|
|
||||||
|
jwtToken := JWTToken{}
|
||||||
|
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||||
|
assert.NoError(t, err, "unable to parse the json input")
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
||||||
|
type parseRequestJWTResponseTest struct {
|
||||||
|
name string
|
||||||
|
inputRespBody string
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedToken string
|
||||||
|
expectedExpiresIn int
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := 100
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Good JWT Body",
|
||||||
|
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: token,
|
||||||
|
expectedExpiresIn: exp,
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "no error was expected",
|
||||||
|
}
|
||||||
|
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||||
|
name: "Parse Bad json JWT Body",
|
||||||
|
inputRespBody: "",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedToken: "",
|
||||||
|
expectedExpiresIn: 0,
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "json error was expected",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||||
|
config := KeycloakClientConfig{}
|
||||||
|
|
||||||
|
creds := KeycloakCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeycloakJwtStillValid(t *testing.T) {
|
||||||
|
type jwtStillValidTest struct {
|
||||||
|
name string
|
||||||
|
inputTime time.Time
|
||||||
|
expectedResult bool
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||||
|
name: "JWT still valid",
|
||||||
|
inputTime: time.Now().Add(10 * time.Second),
|
||||||
|
expectedResult: true,
|
||||||
|
message: "should be true",
|
||||||
|
}
|
||||||
|
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||||
|
name: "JWT is invalid",
|
||||||
|
inputTime: time.Now(),
|
||||||
|
expectedResult: false,
|
||||||
|
message: "should be false",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
config := KeycloakClientConfig{}
|
||||||
|
|
||||||
|
creds := KeycloakCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeycloakAuthenticate(t *testing.T) {
|
||||||
|
type authenticateTest struct {
|
||||||
|
name string
|
||||||
|
inputCode int
|
||||||
|
inputResBody string
|
||||||
|
inputExpireToken time.Time
|
||||||
|
helper ManagerHelper
|
||||||
|
expectedFuncExitErrDiff error
|
||||||
|
expectedCode int
|
||||||
|
expectedToken string
|
||||||
|
}
|
||||||
|
exp := 5
|
||||||
|
token := newTestJWT(t, exp)
|
||||||
|
|
||||||
|
authenticateTestCase1 := authenticateTest{
|
||||||
|
name: "Get Cached token",
|
||||||
|
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: nil,
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase2 := authenticateTest{
|
||||||
|
name: "Get Good JWT Response",
|
||||||
|
inputCode: 200,
|
||||||
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateTestCase3 := authenticateTest{
|
||||||
|
name: "Get Bad Status Code",
|
||||||
|
inputCode: 400,
|
||||||
|
inputResBody: "{}",
|
||||||
|
helper: JsonParser{},
|
||||||
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||||
|
expectedCode: 200,
|
||||||
|
expectedToken: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
jwtReqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputResBody,
|
||||||
|
code: testCase.inputCode,
|
||||||
|
}
|
||||||
|
config := KeycloakClientConfig{}
|
||||||
|
|
||||||
|
creds := KeycloakCredentials{
|
||||||
|
clientConfig: config,
|
||||||
|
httpClient: &jwtReqClient,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||||
|
|
||||||
|
_, err := creds.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
if testCase.expectedFuncExitErrDiff != nil {
|
||||||
|
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
|
||||||
|
type updateUserAppMetadataTest struct {
|
||||||
|
name string
|
||||||
|
inputReqBody string
|
||||||
|
expectedReqBody string
|
||||||
|
appMetadata AppMetadata
|
||||||
|
statusCode int
|
||||||
|
helper ManagerHelper
|
||||||
|
managerCreds ManagerCredentials
|
||||||
|
assertErrFunc assert.ErrorAssertionFunc
|
||||||
|
assertErrFuncMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Authentication",
|
||||||
|
expectedReqBody: "",
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockKeycloakCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
err: fmt.Errorf("error"),
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Status Code",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 400,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockKeycloakCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||||
|
name: "Bad Response Parsing",
|
||||||
|
statusCode: 400,
|
||||||
|
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||||
|
managerCreds: &mockKeycloakCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.Error,
|
||||||
|
assertErrFuncMessage: "should return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||||
|
name: "Good request",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: appMetadata,
|
||||||
|
statusCode: 204,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockKeycloakCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
invite := true
|
||||||
|
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||||
|
name: "Update Pending Invite",
|
||||||
|
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
|
||||||
|
appMetadata: AppMetadata{
|
||||||
|
WTAccountID: "ok",
|
||||||
|
WTPendingInvite: &invite,
|
||||||
|
},
|
||||||
|
statusCode: 204,
|
||||||
|
helper: JsonParser{},
|
||||||
|
managerCreds: &mockKeycloakCredentials{
|
||||||
|
jwtToken: JWTToken{},
|
||||||
|
},
|
||||||
|
assertErrFunc: assert.NoError,
|
||||||
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||||
|
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
reqClient := mockHTTPClient{
|
||||||
|
resBody: testCase.inputReqBody,
|
||||||
|
code: testCase.statusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &KeycloakManager{
|
||||||
|
httpClient: &reqClient,
|
||||||
|
credentials: testCase.managerCreds,
|
||||||
|
helper: testCase.helper,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||||
|
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -528,7 +528,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
SSHKey: peer.SSHKey,
|
SSHKey: peer.SSHKey,
|
||||||
LastLogin: time.Now(),
|
LastLogin: time.Now(),
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: addedByUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
// add peer to 'All' group
|
// add peer to 'All' group
|
||||||
|
@ -197,8 +197,4 @@ func TestUser_GetAllPATs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, 2, len(pats))
|
assert.Equal(t, 2, len(pats))
|
||||||
assert.Equal(t, mockTokenID1, pats[0].ID)
|
|
||||||
assert.Equal(t, mockToken1, pats[0].HashedToken)
|
|
||||||
assert.Equal(t, mockTokenID2, pats[1].ID)
|
|
||||||
assert.Equal(t, mockToken2, pats[1].HashedToken)
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user