Unblock menu when login (#340)

* GetClientID method and increase interval on slow_down err

* Reuse existing authentication flow if is not expired

Created a new struct to hold additional info
 about the flow

 If there is a waiting sso running, we cancel its context

* Run the up command on a goroutine

* Use time.Until

* Use proper ctx and consistently use goroutine for up/down
This commit is contained in:
Maycon Santos 2022-05-28 18:37:08 +02:00 committed by GitHub
parent 59a964eed8
commit c86bacb5c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 22 deletions

View File

@ -16,6 +16,7 @@ type OAuthClient interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
@ -104,6 +105,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
}
}
// GetClientID returns the provider client id
func (h *Hosted) GetClientID(ctx context.Context) string {
return h.ClientID
}
// RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
url := "https://" + h.Domain + "/oauth/device/code"
@ -150,7 +156,8 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
ticker := time.NewTicker(time.Duration(info.Interval) * time.Second)
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
@ -181,7 +188,12 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}

View File

@ -26,14 +26,20 @@ type Server struct {
configPath string
logFile string
oauthClient internal.OAuthClient
deviceAuthInfo internal.DeviceAuthInfo
oauthAuthFlow oauthAuthFlow
mutex sync.Mutex
config *internal.Config
proto.UnimplementedDaemonServiceServer
}
type oauthAuthFlow struct {
expiresAt time.Time
client internal.OAuthClient
info internal.DeviceAuthInfo
waitCancel context.CancelFunc
}
// New server instance constructor.
func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server {
return &Server{
@ -187,6 +193,21 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
providerConfig.ProviderConfig.Domain,
)
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous device flow info")
return &proto.LoginResponse{
NeedsSSOLogin: true,
VerificationURI: s.oauthAuthFlow.info.VerificationURI,
VerificationURIComplete: s.oauthAuthFlow.info.VerificationURIComplete,
UserCode: s.oauthAuthFlow.info.UserCode,
}, nil
} else {
log.Warnf("canceling previous waiting execution")
s.oauthAuthFlow.waitCancel()
}
}
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil {
log.Errorf("getting a request device code failed: %v", err)
@ -194,8 +215,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
s.mutex.Lock()
s.oauthClient = hostedClient
s.deviceAuthInfo = deviceAuthInfo
s.oauthAuthFlow.client = hostedClient
s.oauthAuthFlow.info = deviceAuthInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
state.Set(internal.StatusNeedsLogin)
@ -233,7 +255,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.actCancel = cancel
s.mutex.Unlock()
if s.oauthClient == nil {
if s.oauthAuthFlow.client == nil {
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
}
@ -248,7 +270,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
state.Set(internal.StatusConnecting)
s.mutex.Lock()
deviceAuthInfo := s.deviceAuthInfo
deviceAuthInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if deviceAuthInfo.UserCode != msg.UserCode {
@ -256,12 +278,26 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
}
waitTimeout := time.Duration(deviceAuthInfo.ExpiresIn)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout*time.Second)
if s.oauthAuthFlow.waitCancel != nil {
s.oauthAuthFlow.waitCancel()
}
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := s.oauthClient.WaitToken(waitCTX, deviceAuthInfo)
s.mutex.Lock()
s.oauthAuthFlow.waitCancel = cancel
s.mutex.Unlock()
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo)
if err != nil {
if err == context.Canceled {
return nil, nil
}
s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock()
state.Set(internal.StatusLoginFailed)
log.Errorf("waiting for browser login failed: %v", err)
return nil, err

View File

@ -261,7 +261,7 @@ func (s *serviceClient) menuUpClick() error {
err = s.login()
if err != nil {
log.Errorf("get service status: %v", err)
log.Errorf("login failed with: %v", err)
return err
}
@ -271,16 +271,15 @@ func (s *serviceClient) menuUpClick() error {
return err
}
if status.Status != string(internal.StatusIdle) {
if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected")
return nil
return err
}
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err)
return err
}
return nil
}
@ -388,15 +387,19 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh:
s.mUp.Disable()
if err = s.menuUpClick(); err != nil {
s.mUp.Enable()
}
go func() {
err := s.menuUpClick()
if err != nil {
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
if err = s.menuDownClick(); err != nil {
s.mDown.Enable()
}
go func() {
err := s.menuDownClick()
if err != nil {
return
}
}()
case <-s.mSettings.ClickedCh:
s.mSettings.Disable()
go func() {