mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-12 05:28:44 +02:00
Unblock menu when login (#340)
* GetClientID method and increase interval on slow_down err * Reuse existing authentication flow if is not expired Created a new struct to hold additional info about the flow If there is a waiting sso running, we cancel its context * Run the up command on a goroutine * Use time.Until * Use proper ctx and consistently use goroutine for up/down
This commit is contained in:
parent
59a964eed8
commit
c86bacb5c3
@ -16,6 +16,7 @@ type OAuthClient interface {
|
||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
|
||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||
GetClientID(ctx context.Context) string
|
||||
}
|
||||
|
||||
// HTTPClient http client interface for API calls
|
||||
@ -104,6 +105,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||
return h.ClientID
|
||||
}
|
||||
|
||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||
url := "https://" + h.Domain + "/oauth/device/code"
|
||||
@ -150,7 +156,8 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
||||
ticker := time.NewTicker(time.Duration(info.Interval) * time.Second)
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -181,7 +188,12 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
if tokenResponse.Error != "" {
|
||||
if tokenResponse.Error == "authorization_pending" {
|
||||
continue
|
||||
} else if tokenResponse.Error == "slow_down" {
|
||||
interval = interval + (3 * time.Second)
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
|
@ -26,14 +26,20 @@ type Server struct {
|
||||
configPath string
|
||||
logFile string
|
||||
|
||||
oauthClient internal.OAuthClient
|
||||
deviceAuthInfo internal.DeviceAuthInfo
|
||||
oauthAuthFlow oauthAuthFlow
|
||||
|
||||
mutex sync.Mutex
|
||||
config *internal.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
expiresAt time.Time
|
||||
client internal.OAuthClient
|
||||
info internal.DeviceAuthInfo
|
||||
waitCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server {
|
||||
return &Server{
|
||||
@ -187,6 +193,21 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
providerConfig.ProviderConfig.Domain,
|
||||
)
|
||||
|
||||
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous device flow info")
|
||||
return &proto.LoginResponse{
|
||||
NeedsSSOLogin: true,
|
||||
VerificationURI: s.oauthAuthFlow.info.VerificationURI,
|
||||
VerificationURIComplete: s.oauthAuthFlow.info.VerificationURIComplete,
|
||||
UserCode: s.oauthAuthFlow.info.UserCode,
|
||||
}, nil
|
||||
} else {
|
||||
log.Warnf("canceling previous waiting execution")
|
||||
s.oauthAuthFlow.waitCancel()
|
||||
}
|
||||
}
|
||||
|
||||
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
if err != nil {
|
||||
log.Errorf("getting a request device code failed: %v", err)
|
||||
@ -194,8 +215,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthClient = hostedClient
|
||||
s.deviceAuthInfo = deviceAuthInfo
|
||||
s.oauthAuthFlow.client = hostedClient
|
||||
s.oauthAuthFlow.info = deviceAuthInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
@ -233,7 +255,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.actCancel = cancel
|
||||
s.mutex.Unlock()
|
||||
|
||||
if s.oauthClient == nil {
|
||||
if s.oauthAuthFlow.client == nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
|
||||
}
|
||||
|
||||
@ -248,7 +270,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
s.mutex.Lock()
|
||||
deviceAuthInfo := s.deviceAuthInfo
|
||||
deviceAuthInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if deviceAuthInfo.UserCode != msg.UserCode {
|
||||
@ -256,12 +278,26 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
|
||||
}
|
||||
|
||||
waitTimeout := time.Duration(deviceAuthInfo.ExpiresIn)
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout*time.Second)
|
||||
if s.oauthAuthFlow.waitCancel != nil {
|
||||
s.oauthAuthFlow.waitCancel()
|
||||
}
|
||||
|
||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokenInfo, err := s.oauthClient.WaitToken(waitCTX, deviceAuthInfo)
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.waitCancel = cancel
|
||||
s.mutex.Unlock()
|
||||
|
||||
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo)
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
return nil, nil
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.expiresAt = time.Now()
|
||||
s.mutex.Unlock()
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
log.Errorf("waiting for browser login failed: %v", err)
|
||||
return nil, err
|
||||
|
@ -261,7 +261,7 @@ func (s *serviceClient) menuUpClick() error {
|
||||
|
||||
err = s.login()
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
log.Errorf("login failed with: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -271,16 +271,15 @@ func (s *serviceClient) menuUpClick() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if status.Status != string(internal.StatusIdle) {
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
log.Warnf("already connected")
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -388,15 +387,19 @@ func (s *serviceClient) onTrayReady() {
|
||||
case <-s.mAdminPanel.ClickedCh:
|
||||
err = open.Run(s.adminURL)
|
||||
case <-s.mUp.ClickedCh:
|
||||
s.mUp.Disable()
|
||||
if err = s.menuUpClick(); err != nil {
|
||||
s.mUp.Enable()
|
||||
}
|
||||
go func() {
|
||||
err := s.menuUpClick()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mDown.ClickedCh:
|
||||
s.mDown.Disable()
|
||||
if err = s.menuDownClick(); err != nil {
|
||||
s.mDown.Enable()
|
||||
}
|
||||
go func() {
|
||||
err := s.menuDownClick()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mSettings.ClickedCh:
|
||||
s.mSettings.Disable()
|
||||
go func() {
|
||||
|
Loading…
Reference in New Issue
Block a user