Unblock menu when login (#340)

* GetClientID method and increase interval on slow_down err

* Reuse existing authentication flow if is not expired

Created a new struct to hold additional info
 about the flow

 If there is a waiting sso running, we cancel its context

* Run the up command on a goroutine

* Use time.Until

* Use proper ctx and consistently use goroutine for up/down
This commit is contained in:
Maycon Santos 2022-05-28 18:37:08 +02:00 committed by GitHub
parent 59a964eed8
commit c86bacb5c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 22 deletions

View File

@ -16,6 +16,7 @@ type OAuthClient interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
} }
// HTTPClient http client interface for API calls // HTTPClient http client interface for API calls
@ -104,6 +105,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
} }
} }
// GetClientID returns the provider client id
func (h *Hosted) GetClientID(ctx context.Context) string {
return h.ClientID
}
// RequestDeviceCode requests a device code login flow information from Hosted // RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
url := "https://" + h.Domain + "/oauth/device/code" url := "https://" + h.Domain + "/oauth/device/code"
@ -150,7 +156,8 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
// WaitToken waits user's login and authorize the app. Once the user's authorize // WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning // it retrieves the access token from Hosted's endpoint and validates it before returning
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) { func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
ticker := time.NewTicker(time.Duration(info.Interval) * time.Second) interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -181,7 +188,12 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
if tokenResponse.Error != "" { if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" { if tokenResponse.Error == "authorization_pending" {
continue continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
} }
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
} }

View File

@ -26,14 +26,20 @@ type Server struct {
configPath string configPath string
logFile string logFile string
oauthClient internal.OAuthClient oauthAuthFlow oauthAuthFlow
deviceAuthInfo internal.DeviceAuthInfo
mutex sync.Mutex mutex sync.Mutex
config *internal.Config config *internal.Config
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
} }
type oauthAuthFlow struct {
expiresAt time.Time
client internal.OAuthClient
info internal.DeviceAuthInfo
waitCancel context.CancelFunc
}
// New server instance constructor. // New server instance constructor.
func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server { func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server {
return &Server{ return &Server{
@ -187,6 +193,21 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
providerConfig.ProviderConfig.Domain, providerConfig.ProviderConfig.Domain,
) )
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous device flow info")
return &proto.LoginResponse{
NeedsSSOLogin: true,
VerificationURI: s.oauthAuthFlow.info.VerificationURI,
VerificationURIComplete: s.oauthAuthFlow.info.VerificationURIComplete,
UserCode: s.oauthAuthFlow.info.UserCode,
}, nil
} else {
log.Warnf("canceling previous waiting execution")
s.oauthAuthFlow.waitCancel()
}
}
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO()) deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil { if err != nil {
log.Errorf("getting a request device code failed: %v", err) log.Errorf("getting a request device code failed: %v", err)
@ -194,8 +215,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
} }
s.mutex.Lock() s.mutex.Lock()
s.oauthClient = hostedClient s.oauthAuthFlow.client = hostedClient
s.deviceAuthInfo = deviceAuthInfo s.oauthAuthFlow.info = deviceAuthInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second)
s.mutex.Unlock() s.mutex.Unlock()
state.Set(internal.StatusNeedsLogin) state.Set(internal.StatusNeedsLogin)
@ -233,7 +255,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.actCancel = cancel s.actCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
if s.oauthClient == nil { if s.oauthAuthFlow.client == nil {
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized") return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
} }
@ -248,7 +270,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
s.mutex.Lock() s.mutex.Lock()
deviceAuthInfo := s.deviceAuthInfo deviceAuthInfo := s.oauthAuthFlow.info
s.mutex.Unlock() s.mutex.Unlock()
if deviceAuthInfo.UserCode != msg.UserCode { if deviceAuthInfo.UserCode != msg.UserCode {
@ -256,12 +278,26 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid") return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
} }
waitTimeout := time.Duration(deviceAuthInfo.ExpiresIn) if s.oauthAuthFlow.waitCancel != nil {
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout*time.Second) s.oauthAuthFlow.waitCancel()
}
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel() defer cancel()
tokenInfo, err := s.oauthClient.WaitToken(waitCTX, deviceAuthInfo) s.mutex.Lock()
s.oauthAuthFlow.waitCancel = cancel
s.mutex.Unlock()
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo)
if err != nil { if err != nil {
if err == context.Canceled {
return nil, nil
}
s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock()
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
log.Errorf("waiting for browser login failed: %v", err) log.Errorf("waiting for browser login failed: %v", err)
return nil, err return nil, err

View File

@ -261,7 +261,7 @@ func (s *serviceClient) menuUpClick() error {
err = s.login() err = s.login()
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) log.Errorf("login failed with: %v", err)
return err return err
} }
@ -271,16 +271,15 @@ func (s *serviceClient) menuUpClick() error {
return err return err
} }
if status.Status != string(internal.StatusIdle) { if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected") log.Warnf("already connected")
return nil return err
} }
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err) log.Errorf("up service: %v", err)
return err return err
} }
return nil return nil
} }
@ -388,15 +387,19 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh: case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL) err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh: case <-s.mUp.ClickedCh:
s.mUp.Disable() go func() {
if err = s.menuUpClick(); err != nil { err := s.menuUpClick()
s.mUp.Enable() if err != nil {
return
} }
}()
case <-s.mDown.ClickedCh: case <-s.mDown.ClickedCh:
s.mDown.Disable() go func() {
if err = s.menuDownClick(); err != nil { err := s.menuDownClick()
s.mDown.Enable() if err != nil {
return
} }
}()
case <-s.mSettings.ClickedCh: case <-s.mSettings.ClickedCh:
s.mSettings.Disable() s.mSettings.Disable()
go func() { go func() {