From c86bacb5c3ce3f18a3ae09c6db4410139d73acd2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 28 May 2022 18:37:08 +0200 Subject: [PATCH] Unblock menu when login (#340) * GetClientID method and increase interval on slow_down err * Reuse existing authentication flow if is not expired Created a new struct to hold additional info about the flow If there is a waiting sso running, we cancel its context * Run the up command on a goroutine * Use time.Until * Use proper ctx and consistently use goroutine for up/down --- client/internal/oauth.go | 14 ++++++++++- client/server/server.go | 54 +++++++++++++++++++++++++++++++++------- client/ui/client_ui.go | 27 +++++++++++--------- 3 files changed, 73 insertions(+), 22 deletions(-) diff --git a/client/internal/oauth.go b/client/internal/oauth.go index b29721482..0249b239d 100644 --- a/client/internal/oauth.go +++ b/client/internal/oauth.go @@ -16,6 +16,7 @@ type OAuthClient interface { RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) + GetClientID(ctx context.Context) string } // HTTPClient http client interface for API calls @@ -104,6 +105,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste } } +// GetClientID returns the provider client id +func (h *Hosted) GetClientID(ctx context.Context) string { + return h.ClientID +} + // RequestDeviceCode requests a device code login flow information from Hosted func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { url := "https://" + h.Domain + "/oauth/device/code" @@ -150,7 +156,8 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) // WaitToken waits user's login and authorize the app. Once the user's authorize // it retrieves the access token from Hosted's endpoint and validates it before returning func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) { - ticker := time.NewTicker(time.Duration(info.Interval) * time.Second) + interval := time.Duration(info.Interval) * time.Second + ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): @@ -181,7 +188,12 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, if tokenResponse.Error != "" { if tokenResponse.Error == "authorization_pending" { continue + } else if tokenResponse.Error == "slow_down" { + interval = interval + (3 * time.Second) + ticker.Reset(interval) + continue } + return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) } diff --git a/client/server/server.go b/client/server/server.go index 52acd8465..18794cb60 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -26,14 +26,20 @@ type Server struct { configPath string logFile string - oauthClient internal.OAuthClient - deviceAuthInfo internal.DeviceAuthInfo + oauthAuthFlow oauthAuthFlow mutex sync.Mutex config *internal.Config proto.UnimplementedDaemonServiceServer } +type oauthAuthFlow struct { + expiresAt time.Time + client internal.OAuthClient + info internal.DeviceAuthInfo + waitCancel context.CancelFunc +} + // New server instance constructor. func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server { return &Server{ @@ -187,6 +193,21 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro providerConfig.ProviderConfig.Domain, ) + if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) { + if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { + log.Debugf("using previous device flow info") + return &proto.LoginResponse{ + NeedsSSOLogin: true, + VerificationURI: s.oauthAuthFlow.info.VerificationURI, + VerificationURIComplete: s.oauthAuthFlow.info.VerificationURIComplete, + UserCode: s.oauthAuthFlow.info.UserCode, + }, nil + } else { + log.Warnf("canceling previous waiting execution") + s.oauthAuthFlow.waitCancel() + } + } + deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO()) if err != nil { log.Errorf("getting a request device code failed: %v", err) @@ -194,8 +215,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } s.mutex.Lock() - s.oauthClient = hostedClient - s.deviceAuthInfo = deviceAuthInfo + s.oauthAuthFlow.client = hostedClient + s.oauthAuthFlow.info = deviceAuthInfo + s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second) s.mutex.Unlock() state.Set(internal.StatusNeedsLogin) @@ -233,7 +255,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.actCancel = cancel s.mutex.Unlock() - if s.oauthClient == nil { + if s.oauthAuthFlow.client == nil { return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized") } @@ -248,7 +270,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin state.Set(internal.StatusConnecting) s.mutex.Lock() - deviceAuthInfo := s.deviceAuthInfo + deviceAuthInfo := s.oauthAuthFlow.info s.mutex.Unlock() if deviceAuthInfo.UserCode != msg.UserCode { @@ -256,12 +278,26 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid") } - waitTimeout := time.Duration(deviceAuthInfo.ExpiresIn) - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout*time.Second) + if s.oauthAuthFlow.waitCancel != nil { + s.oauthAuthFlow.waitCancel() + } + + waitTimeout := time.Until(s.oauthAuthFlow.expiresAt) + waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) defer cancel() - tokenInfo, err := s.oauthClient.WaitToken(waitCTX, deviceAuthInfo) + s.mutex.Lock() + s.oauthAuthFlow.waitCancel = cancel + s.mutex.Unlock() + + tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo) if err != nil { + if err == context.Canceled { + return nil, nil + } + s.mutex.Lock() + s.oauthAuthFlow.expiresAt = time.Now() + s.mutex.Unlock() state.Set(internal.StatusLoginFailed) log.Errorf("waiting for browser login failed: %v", err) return nil, err diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7ddc46ae2..41857daed 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -261,7 +261,7 @@ func (s *serviceClient) menuUpClick() error { err = s.login() if err != nil { - log.Errorf("get service status: %v", err) + log.Errorf("login failed with: %v", err) return err } @@ -271,16 +271,15 @@ func (s *serviceClient) menuUpClick() error { return err } - if status.Status != string(internal.StatusIdle) { + if status.Status == string(internal.StatusConnected) { log.Warnf("already connected") - return nil + return err } if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { log.Errorf("up service: %v", err) return err } - return nil } @@ -388,15 +387,19 @@ func (s *serviceClient) onTrayReady() { case <-s.mAdminPanel.ClickedCh: err = open.Run(s.adminURL) case <-s.mUp.ClickedCh: - s.mUp.Disable() - if err = s.menuUpClick(); err != nil { - s.mUp.Enable() - } + go func() { + err := s.menuUpClick() + if err != nil { + return + } + }() case <-s.mDown.ClickedCh: - s.mDown.Disable() - if err = s.menuDownClick(); err != nil { - s.mDown.Enable() - } + go func() { + err := s.menuDownClick() + if err != nil { + return + } + }() case <-s.mSettings.ClickedCh: s.mSettings.Disable() go func() {