mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[management, client] Add logout feature (#4268)
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -24,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -47,6 +50,8 @@ const (
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
)
|
||||
|
||||
var ErrServiceNotUp = errors.New("service is not up")
|
||||
|
||||
// Server for service control.
|
||||
type Server struct {
|
||||
rootCtx context.Context
|
||||
@@ -131,13 +136,7 @@ func (s *Server) Start() error {
|
||||
return fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
config, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
|
||||
@@ -484,13 +483,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
config, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
@@ -701,13 +694,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
config, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
@@ -789,13 +776,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
config, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get default profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get default profile config: %w", err)
|
||||
@@ -811,26 +792,201 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
|
||||
if s.actCancel == nil {
|
||||
return nil, fmt.Errorf("service is not up")
|
||||
}
|
||||
s.actCancel()
|
||||
|
||||
err := s.connectClient.Stop()
|
||||
if err != nil {
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
s.isSessionActive.Store(false)
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) cleanupConnection() error {
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
|
||||
if s.actCancel == nil {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
s.actCancel()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.connectClient.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.connectClient = nil
|
||||
s.isSessionActive.Store(false)
|
||||
|
||||
log.Infof("service is down")
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if msg.ProfileName != nil && *msg.ProfileName != "" {
|
||||
return s.handleProfileLogout(ctx, msg)
|
||||
}
|
||||
|
||||
return s.handleActiveProfileLogout(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
|
||||
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg.Username == nil || *msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
|
||||
}
|
||||
username := *msg.Username
|
||||
|
||||
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
|
||||
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
|
||||
}
|
||||
|
||||
activeProf, _ := s.profileManager.GetActiveProfileState()
|
||||
if activeProf != nil && activeProf.Name == *msg.ProfileName {
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
}
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
}
|
||||
|
||||
return &proto.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutResponse, error) {
|
||||
if s.config == nil {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err)
|
||||
}
|
||||
|
||||
config, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in")
|
||||
}
|
||||
s.config = config
|
||||
}
|
||||
|
||||
if err := s.sendLogoutRequest(ctx); err != nil {
|
||||
log.Errorf("failed to send logout request: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
|
||||
return &proto.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// getConfig loads the config from the active profile
|
||||
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) {
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (s *Server) canRemoveProfile(profileName string) error {
|
||||
if profileName == profilemanager.DefaultProfileName {
|
||||
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("remove active profile: %s", profileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
|
||||
if s.checkProfilesDisabled() {
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if profileName == "" {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
if !allowActiveProfile {
|
||||
if err := s.canRemoveProfile(profileName); err != nil {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
profileState := &profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: username,
|
||||
}
|
||||
profilePath, err := profileState.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get profile path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(profilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
}
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
}
|
||||
|
||||
func (s *Server) sendLogoutRequest(ctx context.Context) error {
|
||||
return s.sendLogoutRequestWithConfig(ctx, s.config)
|
||||
}
|
||||
|
||||
func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profilemanager.Config) error {
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
|
||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to management server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmClient.Close(); err != nil {
|
||||
log.Errorf("close management client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return mgmClient.Logout()
|
||||
}
|
||||
|
||||
// Status returns the daemon status
|
||||
@@ -1107,12 +1263,12 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg.ProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
|
Reference in New Issue
Block a user