mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 10:20:09 +01:00
[client] Add mTLS support for SSO login (#2188)
* Add mTLS support for SSO login * Refactor variable to follow Go naming conventions --------- Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
9716be854d
commit
4bbedb5193
@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
s, ok := gstatus.FromError(err)
|
s, ok := gstatus.FromError(err)
|
||||||
|
@ -86,7 +86,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
|||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
cert := p.providerConfig.ClientCertPair
|
||||||
|
if cert != nil {
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sslClient := &http.Client{Transport: tr}
|
||||||
|
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
token, err := p.handleRequest(req)
|
token, err := p.handleRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
renderPKCEFlowTmpl(w, err)
|
renderPKCEFlowTmpl(w, err)
|
||||||
|
@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@ -57,6 +58,8 @@ type ConfigInput struct {
|
|||||||
DisableAutoConnect *bool
|
DisableAutoConnect *bool
|
||||||
ExtraIFaceBlackList []string
|
ExtraIFaceBlackList []string
|
||||||
DNSRouteInterval *time.Duration
|
DNSRouteInterval *time.Duration
|
||||||
|
ClientCertPath string
|
||||||
|
ClientCertKeyPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@ -102,6 +105,13 @@ type Config struct {
|
|||||||
|
|
||||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
//Path to a certificate used for mTLS authentication
|
||||||
|
ClientCertPath string
|
||||||
|
|
||||||
|
//Path to corresponding private key of ClientCertPath
|
||||||
|
ClientCertKeyPath string
|
||||||
|
|
||||||
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.ClientCertKeyPath != "" {
|
||||||
|
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ClientCertPath != "" {
|
||||||
|
config.ClientCertPath = input.ClientCertPath
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to load mTLS cert/key pair: ", err)
|
||||||
|
} else {
|
||||||
|
config.ClientCertKeyPair = &cert
|
||||||
|
log.Info("Loaded client mTLS cert/key pair")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct {
|
|||||||
RedirectURLs []string
|
RedirectURLs []string
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
UseIDToken bool
|
UseIDToken bool
|
||||||
|
//ClientCertPair is used for mTLS authentication to the IDP
|
||||||
|
ClientCertPair *tls.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
|
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||||
// validate our peer's Wireguard PRIVATE key
|
// validate our peer's Wireguard PRIVATE key
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
|
|||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||||
|
ClientCertPair: clientCert,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
|||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
supportsSSO = false
|
supportsSSO = false
|
||||||
err = nil
|
err = nil
|
||||||
|
Loading…
Reference in New Issue
Block a user