mirror of
https://github.com/netbirdio/netbird.git
synced 2025-05-30 06:40:15 +02:00
Unblock menu when login (#340)
* GetClientID method and increase interval on slow_down err * Reuse existing authentication flow if is not expired Created a new struct to hold additional info about the flow If there is a waiting sso running, we cancel its context * Run the up command on a goroutine * Use time.Until * Use proper ctx and consistently use goroutine for up/down
This commit is contained in:
parent
59a964eed8
commit
c86bacb5c3
@ -16,6 +16,7 @@ type OAuthClient interface {
|
|||||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||||
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
|
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
|
||||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||||
|
GetClientID(ctx context.Context) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient http client interface for API calls
|
// HTTPClient http client interface for API calls
|
||||||
@ -104,6 +105,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClientID returns the provider client id
|
||||||
|
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||||
|
return h.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||||
url := "https://" + h.Domain + "/oauth/device/code"
|
url := "https://" + h.Domain + "/oauth/device/code"
|
||||||
@ -150,7 +156,8 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||||
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
||||||
ticker := time.NewTicker(time.Duration(info.Interval) * time.Second)
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -181,7 +188,12 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
|||||||
if tokenResponse.Error != "" {
|
if tokenResponse.Error != "" {
|
||||||
if tokenResponse.Error == "authorization_pending" {
|
if tokenResponse.Error == "authorization_pending" {
|
||||||
continue
|
continue
|
||||||
|
} else if tokenResponse.Error == "slow_down" {
|
||||||
|
interval = interval + (3 * time.Second)
|
||||||
|
ticker.Reset(interval)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,14 +26,20 @@ type Server struct {
|
|||||||
configPath string
|
configPath string
|
||||||
logFile string
|
logFile string
|
||||||
|
|
||||||
oauthClient internal.OAuthClient
|
oauthAuthFlow oauthAuthFlow
|
||||||
deviceAuthInfo internal.DeviceAuthInfo
|
|
||||||
|
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
config *internal.Config
|
config *internal.Config
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type oauthAuthFlow struct {
|
||||||
|
expiresAt time.Time
|
||||||
|
client internal.OAuthClient
|
||||||
|
info internal.DeviceAuthInfo
|
||||||
|
waitCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server {
|
func New(ctx context.Context, managementURL, adminURL, configPath, logFile string) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
@ -187,6 +193,21 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
providerConfig.ProviderConfig.Domain,
|
providerConfig.ProviderConfig.Domain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||||
|
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||||
|
log.Debugf("using previous device flow info")
|
||||||
|
return &proto.LoginResponse{
|
||||||
|
NeedsSSOLogin: true,
|
||||||
|
VerificationURI: s.oauthAuthFlow.info.VerificationURI,
|
||||||
|
VerificationURIComplete: s.oauthAuthFlow.info.VerificationURIComplete,
|
||||||
|
UserCode: s.oauthAuthFlow.info.UserCode,
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
log.Warnf("canceling previous waiting execution")
|
||||||
|
s.oauthAuthFlow.waitCancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("getting a request device code failed: %v", err)
|
log.Errorf("getting a request device code failed: %v", err)
|
||||||
@ -194,8 +215,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
s.oauthClient = hostedClient
|
s.oauthAuthFlow.client = hostedClient
|
||||||
s.deviceAuthInfo = deviceAuthInfo
|
s.oauthAuthFlow.info = deviceAuthInfo
|
||||||
|
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second)
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
state.Set(internal.StatusNeedsLogin)
|
state.Set(internal.StatusNeedsLogin)
|
||||||
@ -233,7 +255,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.actCancel = cancel
|
s.actCancel = cancel
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
if s.oauthClient == nil {
|
if s.oauthAuthFlow.client == nil {
|
||||||
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
|
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,7 +270,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
state.Set(internal.StatusConnecting)
|
state.Set(internal.StatusConnecting)
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
deviceAuthInfo := s.deviceAuthInfo
|
deviceAuthInfo := s.oauthAuthFlow.info
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
if deviceAuthInfo.UserCode != msg.UserCode {
|
if deviceAuthInfo.UserCode != msg.UserCode {
|
||||||
@ -256,12 +278,26 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
|
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
waitTimeout := time.Duration(deviceAuthInfo.ExpiresIn)
|
if s.oauthAuthFlow.waitCancel != nil {
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout*time.Second)
|
s.oauthAuthFlow.waitCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
||||||
|
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
tokenInfo, err := s.oauthClient.WaitToken(waitCTX, deviceAuthInfo)
|
s.mutex.Lock()
|
||||||
|
s.oauthAuthFlow.waitCancel = cancel
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == context.Canceled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
s.mutex.Lock()
|
||||||
|
s.oauthAuthFlow.expiresAt = time.Now()
|
||||||
|
s.mutex.Unlock()
|
||||||
state.Set(internal.StatusLoginFailed)
|
state.Set(internal.StatusLoginFailed)
|
||||||
log.Errorf("waiting for browser login failed: %v", err)
|
log.Errorf("waiting for browser login failed: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -261,7 +261,7 @@ func (s *serviceClient) menuUpClick() error {
|
|||||||
|
|
||||||
err = s.login()
|
err = s.login()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get service status: %v", err)
|
log.Errorf("login failed with: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,16 +271,15 @@ func (s *serviceClient) menuUpClick() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if status.Status != string(internal.StatusIdle) {
|
if status.Status == string(internal.StatusConnected) {
|
||||||
log.Warnf("already connected")
|
log.Warnf("already connected")
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||||
log.Errorf("up service: %v", err)
|
log.Errorf("up service: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -388,15 +387,19 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
case <-s.mAdminPanel.ClickedCh:
|
case <-s.mAdminPanel.ClickedCh:
|
||||||
err = open.Run(s.adminURL)
|
err = open.Run(s.adminURL)
|
||||||
case <-s.mUp.ClickedCh:
|
case <-s.mUp.ClickedCh:
|
||||||
s.mUp.Disable()
|
go func() {
|
||||||
if err = s.menuUpClick(); err != nil {
|
err := s.menuUpClick()
|
||||||
s.mUp.Enable()
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
case <-s.mDown.ClickedCh:
|
case <-s.mDown.ClickedCh:
|
||||||
s.mDown.Disable()
|
go func() {
|
||||||
if err = s.menuDownClick(); err != nil {
|
err := s.menuDownClick()
|
||||||
s.mDown.Enable()
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
case <-s.mSettings.ClickedCh:
|
case <-s.mSettings.ClickedCh:
|
||||||
s.mSettings.Disable()
|
s.mSettings.Disable()
|
||||||
go func() {
|
go func() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user