diff --git a/client/cmd/logout.go b/client/cmd/logout.go new file mode 100644 index 000000000..071be5ca9 --- /dev/null +++ b/client/cmd/logout.go @@ -0,0 +1,57 @@ +package cmd + +import ( + "context" + "fmt" + "os/user" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/proto" +) + +var logoutCmd = &cobra.Command{ + Use: "logout", + Short: "logout from the Netbird Management Service and delete peer", + RunE: func(cmd *cobra.Command, args []string) error { + SetFlagsFromEnvVars(rootCmd) + + cmd.SetOut(cmd.OutOrStdout()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + defer cancel() + + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return fmt.Errorf("connect to daemon: %v", err) + } + defer conn.Close() + + daemonClient := proto.NewDaemonServiceClient(conn) + + req := &proto.LogoutRequest{} + + if profileName != "" { + req.ProfileName = &profileName + + currUser, err := user.Current() + if err != nil { + return fmt.Errorf("get current user: %v", err) + } + username := currUser.Username + req.Username = &username + } + + if _, err := daemonClient.Logout(ctx, req); err != nil { + return fmt.Errorf("logout: %v", err) + } + + cmd.Println("Logged out successfully") + return nil + }, +} + +func init() { + logoutCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) +} diff --git a/client/cmd/profile.go b/client/cmd/profile.go index f32e9c844..d420dcbd9 100644 --- a/client/cmd/profile.go +++ b/client/cmd/profile.go @@ -3,9 +3,8 @@ package cmd import ( "context" "fmt" - "time" - "os/user" + "time" "github.com/spf13/cobra" @@ -22,10 +21,11 @@ var profileCmd = &cobra.Command{ } var profileListCmd = &cobra.Command{ - Use: "list", - Short: "list all profiles", - Long: `List all available profiles in the Netbird client.`, - RunE: listProfilesFunc, + Use: "list", + Short: "list all profiles", + Long: `List all available profiles in the Netbird client.`, + Aliases: []string{"ls"}, + RunE: listProfilesFunc, } var profileAddCmd = &cobra.Command{ diff --git a/client/cmd/root.go b/client/cmd/root.go index e3ce79964..6a8ae27f4 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -133,6 +133,7 @@ func init() { rootCmd.AddCommand(downCmd) rootCmd.AddCommand(statusCmd) rootCmd.AddCommand(loginCmd) + rootCmd.AddCommand(logoutCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(networksCMD) diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go index 4598af33e..fe0afae2b 100644 --- a/client/internal/profilemanager/profilemanager.go +++ b/client/internal/profilemanager/profilemanager.go @@ -13,7 +13,8 @@ import ( ) const ( - defaultProfileName = "default" + DefaultProfileName = "default" + defaultProfileName = DefaultProfileName // Keep for backward compatibility activeProfileStateFilename = "active_profile.txt" ) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index f405ffd65..7d5ddc8a9 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -4342,6 +4342,94 @@ func (x *GetActiveProfileResponse) GetUsername() string { return "" } +type LogoutRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` + Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LogoutRequest) Reset() { + *x = LogoutRequest{} + mi := &file_daemon_proto_msgTypes[65] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LogoutRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogoutRequest) ProtoMessage() {} + +func (x *LogoutRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[65] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. +func (*LogoutRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{65} +} + +func (x *LogoutRequest) GetProfileName() string { + if x != nil && x.ProfileName != nil { + return *x.ProfileName + } + return "" +} + +func (x *LogoutRequest) GetUsername() string { + if x != nil && x.Username != nil { + return *x.Username + } + return "" +} + +type LogoutResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LogoutResponse) Reset() { + *x = LogoutResponse{} + mi := &file_daemon_proto_msgTypes[66] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LogoutResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogoutResponse) ProtoMessage() {} + +func (x *LogoutResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[66] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. +func (*LogoutResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{66} +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -4352,7 +4440,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4364,7 +4452,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4778,7 +4866,13 @@ const file_daemon_proto_rawDesc = "" + "\x17GetActiveProfileRequest\"X\n" + "\x18GetActiveProfileResponse\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername*b\n" + + "\busername\x18\x02 \x01(\tR\busername\"t\n" + + "\rLogoutRequest\x12%\n" + + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\f_profileNameB\v\n" + + "\t_username\"\x10\n" + + "\x0eLogoutResponse*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -4787,7 +4881,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x84\x0f\n" + + "\x05TRACE\x10\a2\xbf\x0f\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -4817,7 +4911,8 @@ const file_daemon_proto_rawDesc = "" + "AddProfile\x12\x19.daemon.AddProfileRequest\x1a\x1a.daemon.AddProfileResponse\"\x00\x12N\n" + "\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + - "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00B\bZ\x06/protob\x06proto3" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -4832,7 +4927,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 68) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 70) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity @@ -4902,18 +4997,20 @@ var file_daemon_proto_goTypes = []any{ (*Profile)(nil), // 65: daemon.Profile (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse - nil, // 68: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 69: daemon.PortInfo.Range - nil, // 70: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 71: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 72: google.protobuf.Timestamp + (*LogoutRequest)(nil), // 68: daemon.LogoutRequest + (*LogoutResponse)(nil), // 69: daemon.LogoutResponse + nil, // 70: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 71: daemon.PortInfo.Range + nil, // 72: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 73: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 74: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 71, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 73, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 72, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 72, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 71, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 74, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 74, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 73, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -4922,8 +5019,8 @@ var file_daemon_proto_depIdxs = []int32{ 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 68, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 69, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 70, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 71, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -4934,10 +5031,10 @@ var file_daemon_proto_depIdxs = []int32{ 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 72, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 70, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 74, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 72, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 71, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 73, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest @@ -4966,34 +5063,36 @@ var file_daemon_proto_depIdxs = []int32{ 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 5, // 57: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 58: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 59: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 60: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 61: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 62: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 24, // 63: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 26, // 64: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 26, // 65: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 66: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 33, // 67: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 35, // 68: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 37, // 69: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 40, // 70: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 42, // 71: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 44, // 72: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 46, // 73: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse - 50, // 74: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 52, // 75: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 54, // 76: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 56, // 77: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 58, // 78: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 60, // 79: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 62, // 80: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 64, // 81: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 67, // 82: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 57, // [57:83] is the sub-list for method output_type - 31, // [31:57] is the sub-list for method input_type + 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 5, // 58: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 7, // 59: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 9, // 60: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 11, // 61: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 13, // 62: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 15, // 63: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 24, // 64: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 26, // 65: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 26, // 66: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 67: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 33, // 68: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 35, // 69: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 37, // 70: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 40, // 71: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 42, // 72: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 44, // 73: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 46, // 74: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 50, // 75: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 52, // 76: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 54, // 77: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 56, // 78: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 58, // 79: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 60, // 80: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 62, // 81: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 64, // 82: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 67, // 83: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 69, // 84: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 58, // [58:85] is the sub-list for method output_type + 31, // [31:58] is the sub-list for method input_type 31, // [31:31] is the sub-list for extension type_name 31, // [31:31] is the sub-list for extension extendee 0, // [0:31] is the sub-list for field type_name @@ -5014,13 +5113,14 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[46].OneofWrappers = []any{} file_daemon_proto_msgTypes[52].OneofWrappers = []any{} file_daemon_proto_msgTypes[54].OneofWrappers = []any{} + file_daemon_proto_msgTypes[65].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 3, - NumMessages: 68, + NumMessages: 70, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index c25503df9..2e1e0254c 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -79,6 +79,9 @@ service DaemonService { rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {} rpc GetActiveProfile(GetActiveProfileRequest) returns (GetActiveProfileResponse) {} + + // Logout disconnects from the network and deletes the peer from the management server + rpc Logout(LogoutRequest) returns (LogoutResponse) {} } @@ -614,4 +617,11 @@ message GetActiveProfileRequest {} message GetActiveProfileResponse { string profileName = 1; string username = 2; -} \ No newline at end of file +} + +message LogoutRequest { + optional string profileName = 1; + optional string username = 2; +} + +message LogoutResponse {} \ No newline at end of file diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 669083168..edb56bd8a 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -61,6 +61,8 @@ type DaemonServiceClient interface { RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error) GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) + // Logout disconnects from the network and deletes the peer from the management server + Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) } type daemonServiceClient struct { @@ -328,6 +330,15 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv return out, nil } +func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) { + out := new(LogoutResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Logout", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -375,6 +386,8 @@ type DaemonServiceServer interface { RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) + // Logout disconnects from the network and deletes the peer from the management server + Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -460,6 +473,9 @@ func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfi func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetActiveProfile not implemented") } +func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -944,6 +960,24 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex return interceptor(ctx, in, info, handler) } +func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LogoutRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).Logout(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/Logout", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).Logout(ctx, req.(*LogoutRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1051,6 +1085,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetActiveProfile", Handler: _DaemonService_GetActiveProfile_Handler, }, + { + MethodName: "Logout", + Handler: _DaemonService_Logout_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/server.go b/client/server/server.go index 3cb173881..7eb59c91a 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "fmt" "os" "os/exec" @@ -13,6 +14,7 @@ import ( "github.com/cenkalti/backoff/v4" "golang.org/x/exp/maps" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/protobuf/types/known/durationpb" log "github.com/sirupsen/logrus" @@ -24,6 +26,7 @@ import ( "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/client/internal" @@ -47,6 +50,8 @@ const ( errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" ) +var ErrServiceNotUp = errors.New("service is not up") + // Server for service control. type Server struct { rootCtx context.Context @@ -131,13 +136,7 @@ func (s *Server) Start() error { return fmt.Errorf("failed to get active profile state: %w", err) } - cfgPath, err := activeProf.FilePath() - if err != nil { - log.Errorf("failed to get active profile file path: %v", err) - return fmt.Errorf("failed to get active profile file path: %w", err) - } - - config, err := profilemanager.GetConfig(cfgPath) + config, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) @@ -484,13 +483,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } s.mutex.Unlock() - cfgPath, err := activeProf.FilePath() - if err != nil { - log.Errorf("failed to get active profile file path: %v", err) - return nil, fmt.Errorf("failed to get active profile file path: %w", err) - } - - config, err := profilemanager.GetConfig(cfgPath) + config, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) return nil, fmt.Errorf("failed to get active profile config: %w", err) @@ -701,13 +694,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username) - cfgPath, err := activeProf.FilePath() - if err != nil { - log.Errorf("failed to get active profile file path: %v", err) - return nil, fmt.Errorf("failed to get active profile file path: %w", err) - } - - config, err := profilemanager.GetConfig(cfgPath) + config, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) return nil, fmt.Errorf("failed to get active profile config: %w", err) @@ -789,13 +776,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi log.Errorf("failed to get active profile state: %v", err) return nil, fmt.Errorf("failed to get active profile state: %w", err) } - cfgPath, err := activeProf.FilePath() - if err != nil { - log.Errorf("failed to get active profile file path: %v", err) - return nil, fmt.Errorf("failed to get active profile file path: %w", err) - } - - config, err := profilemanager.GetConfig(cfgPath) + config, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get default profile config: %v", err) return nil, fmt.Errorf("failed to get default profile config: %w", err) @@ -811,26 +792,201 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes s.mutex.Lock() defer s.mutex.Unlock() - s.oauthAuthFlow = oauthAuthFlow{} - - if s.actCancel == nil { - return nil, fmt.Errorf("service is not up") - } - s.actCancel() - - err := s.connectClient.Stop() - if err != nil { + if err := s.cleanupConnection(); err != nil { log.Errorf("failed to shut down properly: %v", err) return nil, err } - s.isSessionActive.Store(false) state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) + return &proto.DownResponse{}, nil +} + +func (s *Server) cleanupConnection() error { + s.oauthAuthFlow = oauthAuthFlow{} + + if s.actCancel == nil { + return ErrServiceNotUp + } + s.actCancel() + + if s.connectClient == nil { + return nil + } + + if err := s.connectClient.Stop(); err != nil { + return err + } + + s.connectClient = nil + s.isSessionActive.Store(false) + log.Infof("service is down") - return &proto.DownResponse{}, nil + return nil +} + +func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.ProfileName != nil && *msg.ProfileName != "" { + return s.handleProfileLogout(ctx, msg) + } + + return s.handleActiveProfileLogout(ctx) +} + +func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) { + if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil { + return nil, err + } + + if msg.Username == nil || *msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified") + } + username := *msg.Username + + if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil { + log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err) + return nil, gstatus.Errorf(codes.Internal, "logout: %v", err) + } + + activeProf, _ := s.profileManager.GetActiveProfileState() + if activeProf != nil && activeProf.Name == *msg.ProfileName { + if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { + log.Errorf("failed to cleanup connection: %v", err) + } + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusNeedsLogin) + } + + return &proto.LogoutResponse{}, nil +} + +func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutResponse, error) { + if s.config == nil { + activeProf, err := s.profileManager.GetActiveProfileState() + if err != nil { + return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err) + } + + config, err := s.getConfig(activeProf) + if err != nil { + return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in") + } + s.config = config + } + + if err := s.sendLogoutRequest(ctx); err != nil { + log.Errorf("failed to send logout request: %v", err) + return nil, err + } + + if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { + log.Errorf("failed to cleanup connection: %v", err) + return nil, err + } + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusNeedsLogin) + + return &proto.LogoutResponse{}, nil +} + +// getConfig loads the config from the active profile +func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) { + cfgPath, err := activeProf.FilePath() + if err != nil { + return nil, fmt.Errorf("failed to get active profile file path: %w", err) + } + + config, err := profilemanager.GetConfig(cfgPath) + if err != nil { + return nil, fmt.Errorf("failed to get config: %w", err) + } + + return config, nil +} + +func (s *Server) canRemoveProfile(profileName string) error { + if profileName == profilemanager.DefaultProfileName { + return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName) + } + + activeProf, err := s.profileManager.GetActiveProfileState() + if err == nil && activeProf.Name == profileName { + return fmt.Errorf("remove active profile: %s", profileName) + } + + return nil +} + +func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error { + if s.checkProfilesDisabled() { + return gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + + if profileName == "" { + return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided") + } + + if !allowActiveProfile { + if err := s.canRemoveProfile(profileName); err != nil { + return gstatus.Errorf(codes.InvalidArgument, "%v", err) + } + } + + return nil +} + +// logoutFromProfile logs out from a specific profile by loading its config and sending logout request +func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error { + activeProf, err := s.profileManager.GetActiveProfileState() + if err == nil && activeProf.Name == profileName && s.connectClient != nil { + return s.sendLogoutRequest(ctx) + } + + profileState := &profilemanager.ActiveProfileState{ + Name: profileName, + Username: username, + } + profilePath, err := profileState.FilePath() + if err != nil { + return fmt.Errorf("get profile path: %w", err) + } + + config, err := profilemanager.GetConfig(profilePath) + if err != nil { + return fmt.Errorf("profile '%s' not found", profileName) + } + + return s.sendLogoutRequestWithConfig(ctx, config) +} + +func (s *Server) sendLogoutRequest(ctx context.Context) error { + return s.sendLogoutRequestWithConfig(ctx, s.config) +} + +func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profilemanager.Config) error { + key, err := wgtypes.ParseKey(config.PrivateKey) + if err != nil { + return fmt.Errorf("parse private key: %w", err) + } + + mgmTlsEnabled := config.ManagementURL.Scheme == "https" + mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled) + if err != nil { + return fmt.Errorf("connect to management server: %w", err) + } + defer func() { + if err := mgmClient.Close(); err != nil { + log.Errorf("close management client: %v", err) + } + }() + + return mgmClient.Logout() } // Status returns the daemon status @@ -1107,12 +1263,12 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ s.mutex.Lock() defer s.mutex.Unlock() - if s.checkProfilesDisabled() { - return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { + return nil, err } - if msg.ProfileName == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided") + if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { + log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) } if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index c74412c8b..88cb11eab 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -831,6 +831,7 @@ func (s *serviceClient) onTrayReady() { s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false) s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false) s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false) + s.mSettings.AddSeparator() s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr) s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr) s.loadSettings() diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index c0bc74a2c..e9b7f4f30 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -13,6 +13,7 @@ import ( "fyne.io/systray" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) @@ -231,3 +232,19 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) log.Printf("command '%s %s' completed successfully", command, arg) } + +func (h *eventHandler) logout(ctx context.Context) error { + client, err := h.client.getSrvClient(defaultFailTimeout) + if err != nil { + return fmt.Errorf("failed to get service client: %w", err) + } + + _, err = client.Logout(ctx, &proto.LogoutRequest{}) + if err != nil { + return fmt.Errorf("logout failed: %w", err) + } + + h.client.getSrvConfig() + + return nil +} diff --git a/client/ui/profile.go b/client/ui/profile.go index 779f60aa4..b0502c1fb 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -40,12 +40,13 @@ func (s *serviceClient) showProfilesUI() { list := widget.NewList( func() int { return len(profiles) }, func() fyne.CanvasObject { - // Each item: Selected indicator, Name, spacer, Select & Remove buttons + // Each item: Selected indicator, Name, spacer, Select, Logout & Remove buttons return container.NewHBox( widget.NewLabel(""), // indicator widget.NewLabel(""), // profile name layout.NewSpacer(), widget.NewButton("Select", nil), + widget.NewButton("Logout", nil), widget.NewButton("Remove", nil), ) }, @@ -55,7 +56,8 @@ func (s *serviceClient) showProfilesUI() { indicator := row.Objects[0].(*widget.Label) nameLabel := row.Objects[1].(*widget.Label) selectBtn := row.Objects[3].(*widget.Button) - removeBtn := row.Objects[4].(*widget.Button) + logoutBtn := row.Objects[4].(*widget.Button) + removeBtn := row.Objects[5].(*widget.Button) profile := profiles[i] // Show a checkmark if selected @@ -105,7 +107,7 @@ func (s *serviceClient) showProfilesUI() { return } - status, err := conn.Status(context.Background(), &proto.StatusRequest{}) + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("failed to get status after switching profile: %v", err) return @@ -125,6 +127,12 @@ func (s *serviceClient) showProfilesUI() { ) } + logoutBtn.Show() + logoutBtn.SetText("Logout") + logoutBtn.OnTapped = func() { + s.handleProfileLogout(profile.Name, refresh) + } + // Remove profile removeBtn.SetText("Remove") removeBtn.OnTapped = func() { @@ -135,7 +143,7 @@ func (s *serviceClient) showProfilesUI() { if !confirm { return } - // remove + err = s.removeProfile(profile.Name) if err != nil { log.Errorf("failed to remove profile: %v", err) @@ -230,7 +238,7 @@ func (s *serviceClient) addProfile(profileName string) error { return fmt.Errorf("get current user: %w", err) } - _, err = conn.AddProfile(context.Background(), &proto.AddProfileRequest{ + _, err = conn.AddProfile(s.ctx, &proto.AddProfileRequest{ ProfileName: profileName, Username: currUser.Username, }) @@ -253,7 +261,7 @@ func (s *serviceClient) switchProfile(profileName string) error { return fmt.Errorf("get current user: %w", err) } - if _, err := conn.SwitchProfile(context.Background(), &proto.SwitchProfileRequest{ + if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{ ProfileName: &profileName, Username: &currUser.Username, }); err != nil { @@ -279,7 +287,7 @@ func (s *serviceClient) removeProfile(profileName string) error { return fmt.Errorf("get current user: %w", err) } - _, err = conn.RemoveProfile(context.Background(), &proto.RemoveProfileRequest{ + _, err = conn.RemoveProfile(s.ctx, &proto.RemoveProfileRequest{ ProfileName: profileName, Username: currUser.Username, }) @@ -305,7 +313,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) { if err != nil { return nil, fmt.Errorf("get current user: %w", err) } - profilesResp, err := conn.ListProfiles(context.Background(), &proto.ListProfilesRequest{ + profilesResp, err := conn.ListProfiles(s.ctx, &proto.ListProfilesRequest{ Username: currUser.Username, }) if err != nil { @@ -324,6 +332,52 @@ func (s *serviceClient) getProfiles() ([]Profile, error) { return profiles, nil } +func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) { + dialog.ShowConfirm( + "Logout", + fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName), + func(confirm bool) { + if !confirm { + return + } + + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("failed to get service client: %v", err) + dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles) + return + } + + currUser, err := user.Current() + if err != nil { + log.Errorf("failed to get current user: %v", err) + dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles) + return + } + + username := currUser.Username + _, err = conn.Logout(s.ctx, &proto.LogoutRequest{ + ProfileName: &profileName, + Username: &username, + }) + if err != nil { + log.Errorf("logout failed: %v", err) + dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles) + return + } + + dialog.ShowInformation( + "Logged Out", + fmt.Sprintf("Successfully logged out from '%s'", profileName), + s.wProfiles, + ) + + refreshCallback() + }, + s.wProfiles, + ) +} + type subItem struct { *systray.MenuItem ctx context.Context @@ -339,6 +393,7 @@ type profileMenu struct { emailMenuItem *systray.MenuItem profileSubItems []*subItem manageProfilesSubItem *subItem + logoutSubItem *subItem profilesState []Profile downClickCallback func() error upClickCallback func() error @@ -533,12 +588,11 @@ func (p *profileMenu) refresh() { for { select { case <-ctx.Done(): - return // context cancelled + return case _, ok := <-manageItem.ClickedCh: if !ok { - return // channel closed + return } - // Handle manage profiles click p.eventHandler.runSelfCommand(p.ctx, "profiles", "true") p.refresh() p.loadSettingsCallback() @@ -546,6 +600,30 @@ func (p *profileMenu) refresh() { } }() + // Add Logout menu item + ctx2, cancel2 := context.WithCancel(context.Background()) + logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "") + p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2} + + go func() { + for { + select { + case <-ctx2.Done(): + return + case _, ok := <-logoutItem.ClickedCh: + if !ok { + return + } + if err := p.eventHandler.logout(p.ctx); err != nil { + log.Errorf("logout failed: %v", err) + p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout")) + } else { + p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully")) + } + } + } + }() + if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username { p.profileMenuItem.SetTitle(activeProf.ProfileName) } else { @@ -556,7 +634,6 @@ func (p *profileMenu) refresh() { } func (p *profileMenu) clear(profiles []Profile) { - // Clear existing profile items for _, item := range p.profileSubItems { item.Remove() item.cancel() @@ -565,11 +642,16 @@ func (p *profileMenu) clear(profiles []Profile) { p.profilesState = profiles if p.manageProfilesSubItem != nil { - // Remove the manage profiles item if it exists p.manageProfilesSubItem.Remove() p.manageProfilesSubItem.cancel() p.manageProfilesSubItem = nil } + + if p.logoutSubItem != nil { + p.logoutSubItem.Remove() + p.logoutSubItem.cancel() + p.logoutSubItem = nil + } } func (p *profileMenu) updateMenu() { diff --git a/management/client/client.go b/management/client/client.go index 950f6137e..3a50a155b 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -22,4 +22,5 @@ type Client interface { GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) IsHealthy() bool SyncMeta(sysInfo *system.Info) error + Logout() error } diff --git a/management/client/grpc.go b/management/client/grpc.go index ef26574bd..f181d8b46 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -497,6 +497,32 @@ func (c *GrpcClient) notifyConnected() { c.connStateCallback.MarkManagementConnected() } +func (c *GrpcClient) Logout() error { + serverKey, err := c.GetServerPublicKey() + if err != nil { + return fmt.Errorf("get server public key: %w", err) + } + + mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5) + defer cancel() + + message := &proto.Empty{} + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) + if err != nil { + return fmt.Errorf("encrypt logout message: %w", err) + } + + _, err = c.realClient.Logout(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encryptedMSG, + }) + if err != nil { + return fmt.Errorf("logout: %w", err) + } + + return nil +} + func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil diff --git a/management/client/mock.go b/management/client/mock.go index 9e1786f82..8e1a13705 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -19,6 +19,7 @@ type MockClient struct { GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) SyncMetaFunc func(sysInfo *system.Info) error + LogoutFunc func() error } func (m *MockClient) IsHealthy() bool { @@ -85,3 +86,10 @@ func (m *MockClient) SyncMeta(sysInfo *system.Info) error { } return m.SyncMetaFunc(sysInfo) } + +func (m *MockClient) Logout() error { + if m.LogoutFunc == nil { + return nil + } + return m.LogoutFunc() +} diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 8503f2e94..848610c78 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -3825,7 +3825,7 @@ var file_management_proto_rawDesc = []byte{ 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, - 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, + 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, @@ -3858,8 +3858,12 @@ var file_management_proto_rawDesc = []byte{ 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, + 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3986,15 +3990,17 @@ var file_management_proto_depIdxs = []int32{ 5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage 5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage 5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 62: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 63: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 64: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 65: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 66: management.ManagementService.SyncMeta:output_type -> management.Empty - 60, // [60:67] is the sub-list for method output_type - 53, // [53:60] is the sub-list for method input_type + 5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 68: management.ManagementService.Logout:output_type -> management.Empty + 61, // [61:69] is the sub-list for method output_type + 53, // [53:61] is the sub-list for method input_type 53, // [53:53] is the sub-list for extension type_name 53, // [53:53] is the sub-list for extension extendee 0, // [0:53] is the sub-list for field type_name diff --git a/management/proto/management.proto b/management/proto/management.proto index 8e137df93..d5441d352 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -45,6 +45,9 @@ service ManagementService { // sync meta will evaluate the checks and update the peer meta with the result. // EncryptedMessage of the request has a body of Empty. rpc SyncMeta(EncryptedMessage) returns (Empty) {} + + // Logout logs out the peer and removes it from the management server + rpc Logout(EncryptedMessage) returns (Empty) {} } message EncryptedMessage { diff --git a/management/proto/management_grpc.pb.go b/management/proto/management_grpc.pb.go index badf242f5..5b189334d 100644 --- a/management/proto/management_grpc.pb.go +++ b/management/proto/management_grpc.pb.go @@ -48,6 +48,8 @@ type ManagementServiceClient interface { // sync meta will evaluate the checks and update the peer meta with the result. // EncryptedMessage of the request has a body of Empty. SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) + // Logout logs out the peer and removes it from the management server + Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) } type managementServiceClient struct { @@ -144,6 +146,15 @@ func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMes return out, nil } +func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := c.cc.Invoke(ctx, "/management.ManagementService/Logout", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ManagementServiceServer is the server API for ManagementService service. // All implementations must embed UnimplementedManagementServiceServer // for forward compatibility @@ -178,6 +189,8 @@ type ManagementServiceServer interface { // sync meta will evaluate the checks and update the peer meta with the result. // EncryptedMessage of the request has a body of Empty. SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) + // Logout logs out the peer and removes it from the management server + Logout(context.Context, *EncryptedMessage) (*Empty, error) mustEmbedUnimplementedManagementServiceServer() } @@ -206,6 +219,9 @@ func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Con func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) { return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented") } +func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") +} func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {} // UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service. @@ -348,6 +364,24 @@ func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EncryptedMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).Logout(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/Logout", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).Logout(ctx, req.(*EncryptedMessage)) + } + return interceptor(ctx, in, info, handler) +} + // ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -379,6 +413,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SyncMeta", Handler: _ManagementService_SyncMeta_Handler, }, + { + MethodName: "Logout", + Handler: _ManagementService_Logout_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2f1bc3673..b121cc993 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -19,7 +19,9 @@ import ( "google.golang.org/grpc/status" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" @@ -909,6 +911,44 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } +func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { + log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) + + empty := &proto.Empty{} + peerKey, err := s.parseRequest(ctx, req, empty) + if err != nil { + return nil, err + } + + peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String()) + if err != nil { + log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String()) + // TODO: consider idempotency + return nil, mapError(ctx, err) + } + + // nolint:staticcheck + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.ID) + // nolint:staticcheck + ctx = context.WithValue(ctx, nbContext.AccountIDKey, peer.AccountID) + + userID := peer.UserID + if userID == "" { + userID = activity.SystemInitiator + } + + if err = s.accountManager.DeletePeer(ctx, peer.AccountID, peer.ID, userID); err != nil { + log.WithContext(ctx).Errorf("failed to logout peer %s: %v", peerKey.String(), err) + return nil, mapError(ctx, err) + } + + s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) + + log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String()) + + return &proto.Empty{}, nil +} + // toProtocolChecks converts posture checks to protocol checks. func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks { protoChecks := make([]*proto.Checks, 0, len(postureChecks))