mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-25 23:38:42 +01:00
Improve Client Authentication (#1135)
* shutdown the pkce server on user cancellation * Refactor openURL to exclusively manage authentication flow instructions and browser launching * Refactor authentication flow initialization based on client OS The NewOAuthFlow method now first checks the operating system and if it is a non-desktop Linux, it opts for Device Code Flow. PKCEFlow is tried first and if it fails, then it falls back on Device Code Flow. If both unsuccessful, the authentication process halts and error messages have been updated to provide more helpful feedback for troubleshooting authentication errors * Replace log-based Linux desktop check with process check To verify if a Linux OS is running a desktop environment in the Authentication utility, the log-based method that checks the XDG_CURRENT_DESKTOP env has been replaced with a method that checks directly if either X or Wayland display server processes are running. This method is more reliable as it directly checks for the display server process rather than relying on an environment variable that may not be set in all desktop environments. * Refactor PKCE Authorization Flow to improve server handling * refactor check for linux running desktop environment * Improve server shutdown handling and encapsulate handlers with new server multiplexer The changes enhance the way the server shuts down by specifying a context with timeout of 5 seconds, adding a safeguard to ensure the server halts even on potential hanging requests. Also, the server's root handler is now encapsulated within a new ServeMux instance, to support multiple registrations of a path
This commit is contained in:
parent
34e2c6b943
commit
8febab4076
@ -3,8 +3,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -195,51 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
|||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
browserAuthMsg := "Please do the SSO login in your browser. \n" +
|
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
verificationURIComplete + " " + codeMsg
|
verificationURIComplete + " " + codeMsg)
|
||||||
|
cmd.Println("")
|
||||||
setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" +
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys"
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
authenticateUsingBrowser := func() {
|
|
||||||
cmd.Println(browserAuthMsg)
|
|
||||||
cmd.Println("")
|
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
|
||||||
cmd.Println(setupKeyAuthMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "windows", "darwin":
|
|
||||||
authenticateUsingBrowser()
|
|
||||||
case "linux":
|
|
||||||
if isLinuxRunningDesktop() {
|
|
||||||
authenticateUsingBrowser()
|
|
||||||
} else {
|
|
||||||
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
|
|
||||||
// Devices lacking browser support are incompatible with this flow.Therefore,
|
|
||||||
// these devices will need to resort to setup keys instead.
|
|
||||||
if isPKCEFlow(verificationURIComplete) {
|
|
||||||
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
|
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
|
||||||
} else {
|
|
||||||
cmd.Println(browserAuthMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
|
|
||||||
func isLinuxRunningDesktop() bool {
|
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPKCEFlow determines if the PKCE flow is active or not,
|
|
||||||
// by checking the existence of redirect_uri inside the verification URL.
|
|
||||||
func isPKCEFlow(verificationURL string) bool {
|
|
||||||
if verificationURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.Contains(verificationURL, "redirect_uri")
|
|
||||||
}
|
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
|
//
|
||||||
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
|
//
|
||||||
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
log.Debug("loading pkce authorization flow info")
|
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err == nil {
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("loading pkce authorization flow info failed with error: %v", err)
|
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||||
log.Debugf("falling back to device authorization flow info")
|
if err != nil {
|
||||||
|
// fallback to device code flow
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
|
}
|
||||||
|
return pkceFlow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
|
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
|
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := gstatus.FromError(err)
|
s, ok := gstatus.FromError(err)
|
||||||
if ok && s.Code() == codes.NotFound {
|
if ok && s.Code() == codes.NotFound {
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
"If you are using hosting Netbird see documentation at " +
|
"Please proceed with setting up this device using setup keys " +
|
||||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
} else if ok && s.Code() == codes.Unimplemented {
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a authorization code login flow information.
|
// RequestAuthInfo requests a authorization code login flow information.
|
||||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
state, err := randomBytesInHex(24)
|
state, err := randomBytesInHex(24)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||||
@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go p.startServer(tokenChan, errChan)
|
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
|
defer func() {
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Errorf("failed to close the server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, ctx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.handleOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
return TokenInfo{}, err
|
return TokenInfo{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||||
var wg sync.WaitGroup
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
token, err := p.handleRequest(req)
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
|
||||||
go func() {
|
|
||||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
|
||||||
query := req.URL.Query()
|
|
||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prevent timing attacks on state
|
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
|
||||||
return nil, fmt.Errorf("invalid state")
|
|
||||||
}
|
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
|
||||||
if code == "" {
|
|
||||||
return nil, fmt.Errorf("missing code")
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
|
||||||
req.Context(),
|
|
||||||
code,
|
|
||||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := tokenValidatorFunc()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
renderPKCEFlowTmpl(w, err)
|
renderPKCEFlowTmpl(w, err)
|
||||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||||
@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
|
|||||||
tokenChan <- token
|
tokenChan <- token
|
||||||
})
|
})
|
||||||
|
|
||||||
wg.Wait()
|
server.Handler = mux
|
||||||
if err := server.Shutdown(context.Background()); err != nil {
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
errChan <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
||||||
|
query := req.URL.Query()
|
||||||
|
|
||||||
|
if authError := query.Get(queryError); authError != "" {
|
||||||
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
|
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent timing attacks on the state
|
||||||
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid state")
|
||||||
|
}
|
||||||
|
|
||||||
|
code := query.Get(queryCode)
|
||||||
|
if code == "" {
|
||||||
|
return nil, fmt.Errorf("missing code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.oAuthConfig.Exchange(
|
||||||
|
req.Context(),
|
||||||
|
code,
|
||||||
|
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||||
tokenInfo := TokenInfo{
|
tokenInfo := TokenInfo{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
|
|||||||
|
|
||||||
return fmt.Errorf("invalid JWT token audience field")
|
return fmt.Errorf("invalid JWT token audience field")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
|
func isLinuxRunningDesktop() bool {
|
||||||
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user