diff --git a/client/cmd/networks.go b/client/cmd/networks.go new file mode 100644 index 000000000..6ebf13810 --- /dev/null +++ b/client/cmd/networks.go @@ -0,0 +1,175 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var appendFlag bool + +var networksCMD = &cobra.Command{ + Use: "networks", + Aliases: []string{"routes"}, + Short: "Manage networks", + Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, +} + +var routesListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List networks", + Example: " netbird networks list", + Long: "List all available network routes.", + RunE: networksList, +} + +var routesSelectCmd = &cobra.Command{ + Use: "select network...|all", + Short: "Select network", + Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.", + Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3", + Args: cobra.MinimumNArgs(1), + RunE: networksSelect, +} + +var routesDeselectCmd = &cobra.Command{ + Use: "deselect network...|all", + Short: "Deselect networks", + Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.", + Example: " netbird networks deselect all\n netbird networks deselect route1 route2", + Args: cobra.MinimumNArgs(1), + RunE: networksDeselect, +} + +func init() { + routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing") +} + +func networksList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{}) + if err != nil { + return fmt.Errorf("failed to list network: %v", status.Convert(err).Message()) + } + + if len(resp.Routes) == 0 { + cmd.Println("No networks available.") + return nil + } + + printRoutes(cmd, resp) + + return nil +} + +func printRoutes(cmd *cobra.Command, resp *proto.ListNetworksResponse) { + cmd.Println("Available Networks:") + for _, route := range resp.Routes { + printRoute(cmd, route) + } +} + +func printRoute(cmd *cobra.Command, route *proto.Network) { + selectedStatus := getSelectedStatus(route) + domains := route.GetDomains() + + if len(domains) > 0 { + printDomainRoute(cmd, route, domains, selectedStatus) + } else { + printNetworkRoute(cmd, route, selectedStatus) + } +} + +func getSelectedStatus(route *proto.Network) string { + if route.GetSelected() { + return "Selected" + } + return "Not Selected" +} + +func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) + resolvedIPs := route.GetResolvedIPs() + + if len(resolvedIPs) > 0 { + printResolvedIPs(cmd, domains, resolvedIPs) + } else { + cmd.Printf(" Resolved IPs: -\n") + } +} + +func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus) +} + +func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { + cmd.Printf(" Resolved IPs:\n") + for _, domain := range domains { + if ipList, exists := resolvedIPs[domain]; exists { + cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) + } + } +} + +func networksSelect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } else if appendFlag { + req.Append = true + } + + if _, err := client.SelectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks selected successfully.") + + return nil +} + +func networksDeselect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } + + if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks deselected successfully.") + + return nil +} diff --git a/client/cmd/root.go b/client/cmd/root.go index 3f2d04ef3..0305bacc8 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -142,14 +142,14 @@ func init() { rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) - rootCmd.AddCommand(routesCmd) + rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(debugCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - routesCmd.AddCommand(routesListCmd) - routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd) + networksCMD.AddCommand(routesListCmd) + networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(logCmd) diff --git a/client/cmd/route.go b/client/cmd/route.go deleted file mode 100644 index c8881822b..000000000 --- a/client/cmd/route.go +++ /dev/null @@ -1,174 +0,0 @@ -package cmd - -import ( - "fmt" - "strings" - - "github.com/spf13/cobra" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/proto" -) - -var appendFlag bool - -var routesCmd = &cobra.Command{ - Use: "routes", - Short: "Manage network routes", - Long: `Commands to list, select, or deselect network routes.`, -} - -var routesListCmd = &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List routes", - Example: " netbird routes list", - Long: "List all available network routes.", - RunE: routesList, -} - -var routesSelectCmd = &cobra.Command{ - Use: "select route...|all", - Short: "Select routes", - Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.", - Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3", - Args: cobra.MinimumNArgs(1), - RunE: routesSelect, -} - -var routesDeselectCmd = &cobra.Command{ - Use: "deselect route...|all", - Short: "Deselect routes", - Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.", - Example: " netbird routes deselect all\n netbird routes deselect route1 route2", - Args: cobra.MinimumNArgs(1), - RunE: routesDeselect, -} - -func init() { - routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing") -} - -func routesList(cmd *cobra.Command, _ []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{}) - if err != nil { - return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message()) - } - - if len(resp.Routes) == 0 { - cmd.Println("No routes available.") - return nil - } - - printRoutes(cmd, resp) - - return nil -} - -func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) { - cmd.Println("Available Routes:") - for _, route := range resp.Routes { - printRoute(cmd, route) - } -} - -func printRoute(cmd *cobra.Command, route *proto.Route) { - selectedStatus := getSelectedStatus(route) - domains := route.GetDomains() - - if len(domains) > 0 { - printDomainRoute(cmd, route, domains, selectedStatus) - } else { - printNetworkRoute(cmd, route, selectedStatus) - } -} - -func getSelectedStatus(route *proto.Route) string { - if route.GetSelected() { - return "Selected" - } - return "Not Selected" -} - -func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) - resolvedIPs := route.GetResolvedIPs() - - if len(resolvedIPs) > 0 { - printResolvedIPs(cmd, domains, resolvedIPs) - } else { - cmd.Printf(" Resolved IPs: -\n") - } -} - -func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) -} - -func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { - cmd.Printf(" Resolved IPs:\n") - for _, domain := range domains { - if ipList, exists := resolvedIPs[domain]; exists { - cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) - } - } -} - -func routesSelect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } else if appendFlag { - req.Append = true - } - - if _, err := client.SelectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes selected successfully.") - - return nil -} - -func routesDeselect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } - - if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes deselected successfully.") - - return nil -} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6db52a677..fa4bff77b 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -40,6 +40,7 @@ type peerStateDetailOutput struct { Latency time.Duration `json:"latency" yaml:"latency"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` } type peersStateOutput struct { @@ -98,6 +99,7 @@ type statusOutputOverview struct { RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` } @@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), - Routes: pbFullStatus.GetLocalPeerState().GetRoutes(), + Routes: pbFullStatus.GetLocalPeerState().GetNetworks(), + Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), } @@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { TransferSent: transferSent, Latency: pbPeerState.GetLatency().AsDuration(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), - Routes: pbPeerState.GetRoutes(), + Routes: pbPeerState.GetNetworks(), + Networks: pbPeerState.GetNetworks(), } peersStateDetail = append(peersStateDetail, peerState) @@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) } - routes := "-" - if len(overview.Routes) > 0 { - sort.Strings(overview.Routes) - routes = strings.Join(overview.Routes, ", ") + networks := "-" + if len(overview.Networks) > 0 { + sort.Strings(overview.Networks) + networks = strings.Join(overview.Networks, ", ") } var dnsServersString string @@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Routes: %s\n"+ + "Networks: %s\n"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), overview.DaemonVersion, @@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays interfaceIP, interfaceTypeString, rosenpassEnabledStatus, - routes, + networks, + networks, peersCountString, ) return summary @@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo } } - routes := "-" - if len(peerState.Routes) > 0 { - sort.Strings(peerState.Routes) - routes = strings.Join(peerState.Routes, ", ") + networks := "-" + if len(peerState.Networks) > 0 { + sort.Strings(peerState.Networks) + networks = strings.Join(peerState.Networks, ", ") } peerString := fmt.Sprintf( @@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Transfer status (received/sent) %s/%s\n"+ " Quantum resistance: %s\n"+ " Routes: %s\n"+ + " Networks: %s\n"+ " Latency: %s\n", peerState.FQDN, peerState.IP, @@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferReceived), toIEC(peerState.TransferSent), rosenpassEnabledStatus, - routes, + networks, + networks, peerState.Latency.String(), ) @@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeIPString(route) + } + + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } @@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } } + for i, route := range overview.Networks { + overview.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range overview.Routes { overview.Routes[i] = a.AnonymizeRoute(route) } diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index ca43df8a5..1f1e95726 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), BytesRx: 200, BytesTx: 100, - Routes: []string{ + Networks: []string{ "10.1.0.0/24", }, Latency: durationpb.New(time.Duration(10000000)), @@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{ PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", - Routes: []string{ + Networks: []string{ "10.10.0.0/24", }, }, @@ -149,6 +149,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.1.0.0/24", }, + Networks: []string{ + "10.1.0.0/24", + }, Latency: time.Duration(10000000), }, { @@ -230,6 +233,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.10.0.0/24", }, + Networks: []string{ + "10.10.0.0/24", + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) { "quantumResistance": false, "routes": [ "10.1.0.0/24" + ], + "networks": [ + "10.1.0.0/24" ] }, { @@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) { "transferSent": 1000, "latency": 10000000, "quantumResistance": false, - "routes": null + "routes": null, + "networks": null } ] }, @@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) { "routes": [ "10.10.0.0/24" ], + "networks": [ + "10.10.0.0/24" + ], "dnsServers": [ { "servers": [ @@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) { quantumResistance: false routes: - 10.1.0.0/24 + networks: + - 10.1.0.0/24 - fqdn: peer-2.awesome-domain.com netbirdIp: 192.168.178.102 publicKey: Pubkey2 @@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) { latency: 10ms quantumResistance: false routes: [] + networks: [] cliVersion: development daemonVersion: 0.14.1 management: @@ -465,6 +481,8 @@ quantumResistance: false quantumResistancePermissive: false routes: - 10.10.0.0/24 +networks: + - 10.10.0.0/24 dnsServers: - servers: - 8.8.8.8:53 @@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 200 B/100 B Quantum resistance: false Routes: 10.1.0.0/24 + Networks: 10.1.0.0/24 Latency: 10ms peer-2.awesome-domain.com: @@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false Routes: - + Networks: - Latency: 10ms OS: %s/%s @@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected ` diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d998f9ea9..527a6badb 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -10,6 +10,7 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -71,7 +72,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f1fec67e7..9305c0b5a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -39,6 +39,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -1196,7 +1197,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 98ce2c4a2..f0d3021e9 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v3.21.9 // source: daemon.proto package proto @@ -908,7 +908,7 @@ type PeerState struct { BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` } @@ -1043,9 +1043,9 @@ func (x *PeerState) GetRosenpassEnabled() bool { return false } -func (x *PeerState) GetRoutes() []string { +func (x *PeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1076,7 +1076,7 @@ type LocalPeerState struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - Routes []string `protobuf:"bytes,7,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` } func (x *LocalPeerState) Reset() { @@ -1153,9 +1153,9 @@ func (x *LocalPeerState) GetRosenpassPermissive() bool { return false } -func (x *LocalPeerState) GetRoutes() []string { +func (x *LocalPeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1511,14 +1511,14 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState { return nil } -type ListRoutesRequest struct { +type ListNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *ListRoutesRequest) Reset() { - *x = ListRoutesRequest{} +func (x *ListNetworksRequest) Reset() { + *x = ListNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1526,13 +1526,13 @@ func (x *ListRoutesRequest) Reset() { } } -func (x *ListRoutesRequest) String() string { +func (x *ListNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesRequest) ProtoMessage() {} +func (*ListNetworksRequest) ProtoMessage() {} -func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1544,21 +1544,21 @@ func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead. -func (*ListRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. +func (*ListNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{19} } -type ListRoutesResponse struct { +type ListNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` + Routes []*Network `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` } -func (x *ListRoutesResponse) Reset() { - *x = ListRoutesResponse{} +func (x *ListNetworksResponse) Reset() { + *x = ListNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1566,13 +1566,13 @@ func (x *ListRoutesResponse) Reset() { } } -func (x *ListRoutesResponse) String() string { +func (x *ListNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesResponse) ProtoMessage() {} +func (*ListNetworksResponse) ProtoMessage() {} -func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1584,30 +1584,30 @@ func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead. -func (*ListRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. +func (*ListNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{20} } -func (x *ListRoutesResponse) GetRoutes() []*Route { +func (x *ListNetworksResponse) GetRoutes() []*Network { if x != nil { return x.Routes } return nil } -type SelectRoutesRequest struct { +type SelectNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"` - Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` - All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` + NetworkIDs []string `protobuf:"bytes,1,rep,name=networkIDs,proto3" json:"networkIDs,omitempty"` + Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` + All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` } -func (x *SelectRoutesRequest) Reset() { - *x = SelectRoutesRequest{} +func (x *SelectNetworksRequest) Reset() { + *x = SelectNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1615,13 +1615,13 @@ func (x *SelectRoutesRequest) Reset() { } } -func (x *SelectRoutesRequest) String() string { +func (x *SelectNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesRequest) ProtoMessage() {} +func (*SelectNetworksRequest) ProtoMessage() {} -func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1633,40 +1633,40 @@ func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead. -func (*SelectRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. +func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{21} } -func (x *SelectRoutesRequest) GetRouteIDs() []string { +func (x *SelectNetworksRequest) GetNetworkIDs() []string { if x != nil { - return x.RouteIDs + return x.NetworkIDs } return nil } -func (x *SelectRoutesRequest) GetAppend() bool { +func (x *SelectNetworksRequest) GetAppend() bool { if x != nil { return x.Append } return false } -func (x *SelectRoutesRequest) GetAll() bool { +func (x *SelectNetworksRequest) GetAll() bool { if x != nil { return x.All } return false } -type SelectRoutesResponse struct { +type SelectNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *SelectRoutesResponse) Reset() { - *x = SelectRoutesResponse{} +func (x *SelectNetworksResponse) Reset() { + *x = SelectNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1674,13 +1674,13 @@ func (x *SelectRoutesResponse) Reset() { } } -func (x *SelectRoutesResponse) String() string { +func (x *SelectNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesResponse) ProtoMessage() {} +func (*SelectNetworksResponse) ProtoMessage() {} -func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1692,8 +1692,8 @@ func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead. -func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. +func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{22} } @@ -1744,20 +1744,20 @@ func (x *IPList) GetIps() []string { return nil } -type Route struct { +type Network struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Range string `protobuf:"bytes,2,opt,name=range,proto3" json:"range,omitempty"` Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` Domains []string `protobuf:"bytes,4,rep,name=domains,proto3" json:"domains,omitempty"` ResolvedIPs map[string]*IPList `protobuf:"bytes,5,rep,name=resolvedIPs,proto3" json:"resolvedIPs,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } -func (x *Route) Reset() { - *x = Route{} +func (x *Network) Reset() { + *x = Network{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1765,13 +1765,13 @@ func (x *Route) Reset() { } } -func (x *Route) String() string { +func (x *Network) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Route) ProtoMessage() {} +func (*Network) ProtoMessage() {} -func (x *Route) ProtoReflect() protoreflect.Message { +func (x *Network) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1783,40 +1783,40 @@ func (x *Route) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Route.ProtoReflect.Descriptor instead. -func (*Route) Descriptor() ([]byte, []int) { +// Deprecated: Use Network.ProtoReflect.Descriptor instead. +func (*Network) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{24} } -func (x *Route) GetID() string { +func (x *Network) GetID() string { if x != nil { return x.ID } return "" } -func (x *Route) GetNetwork() string { +func (x *Network) GetRange() string { if x != nil { - return x.Network + return x.Range } return "" } -func (x *Route) GetSelected() bool { +func (x *Network) GetSelected() bool { if x != nil { return x.Selected } return false } -func (x *Route) GetDomains() []string { +func (x *Network) GetDomains() []string { if x != nil { return x.Domains } return nil } -func (x *Route) GetResolvedIPs() map[string]*IPList { +func (x *Network) GetResolvedIPs() map[string]*IPList { if x != nil { return x.ResolvedIPs } @@ -2671,7 +2671,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65, + 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, @@ -2710,233 +2710,235 @@ var file_daemon_proto_rawDesc = []byte{ 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, - 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, - 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, - 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, - 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, - 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, - 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, - 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, - 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, - 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, - 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, - 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, - 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, - 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a, - 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, - 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, - 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, - 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, - 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, - 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, - 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, - 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, - 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, - 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, - 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, - 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a, - 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, - 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, - 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, - 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, - 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, - 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, - 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, - 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, - 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, - 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, - 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, + 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, + 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, + 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, 0x0a, + 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, + 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, + 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, + 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x6e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, + 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, + 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, + 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, + 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, + 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, + 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, + 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, + 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, + 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, + 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, + 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, + 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, + 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, + 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, + 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, + 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, + 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, + 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, + 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, + 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2974,12 +2976,12 @@ var file_daemon_proto_goTypes = []interface{}{ (*RelayState)(nil), // 17: daemon.RelayState (*NSGroupState)(nil), // 18: daemon.NSGroupState (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse + (*ListNetworksRequest)(nil), // 20: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 21: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 22: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 23: daemon.SelectNetworksResponse (*IPList)(nil), // 24: daemon.IPList - (*Route)(nil), // 25: daemon.Route + (*Network)(nil), // 25: daemon.Network (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest @@ -2995,7 +2997,7 @@ var file_daemon_proto_goTypes = []interface{}{ (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse - nil, // 41: daemon.Route.ResolvedIPsEntry + nil, // 41: daemon.Network.ResolvedIPsEntry (*durationpb.Duration)(nil), // 42: google.protobuf.Duration (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp } @@ -3011,21 +3013,21 @@ var file_daemon_proto_depIdxs = []int32{ 13, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State - 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList + 24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest @@ -3039,9 +3041,9 @@ var file_daemon_proto_depIdxs = []int32{ 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse @@ -3291,7 +3293,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesRequest); i { + switch v := v.(*ListNetworksRequest); i { case 0: return &v.state case 1: @@ -3303,7 +3305,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesResponse); i { + switch v := v.(*ListNetworksResponse); i { case 0: return &v.state case 1: @@ -3315,7 +3317,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesRequest); i { + switch v := v.(*SelectNetworksRequest); i { case 0: return &v.state case 1: @@ -3327,7 +3329,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesResponse); i { + switch v := v.(*SelectNetworksResponse); i { case 0: return &v.state case 1: @@ -3351,7 +3353,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*Network); i { case 0: return &v.state case 1: diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 96ade5b4e..cddf78242 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -28,14 +28,14 @@ service DaemonService { // GetConfig of the daemon. rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} - // List available network routes - rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {} + // List available networks + rpc ListNetworks(ListNetworksRequest) returns (ListNetworksResponse) {} // Select specific routes - rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc SelectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // Deselect specific routes - rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // DebugBundle creates a debug bundle rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {} @@ -190,7 +190,7 @@ message PeerState { int64 bytesRx = 13; int64 bytesTx = 14; bool rosenpassEnabled = 15; - repeated string routes = 16; + repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; } @@ -203,7 +203,7 @@ message LocalPeerState { string fqdn = 4; bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; - repeated string routes = 7; + repeated string networks = 7; } // SignalState contains the latest state of a signal connection @@ -244,20 +244,20 @@ message FullStatus { repeated NSGroupState dns_servers = 6; } -message ListRoutesRequest { +message ListNetworksRequest { } -message ListRoutesResponse { - repeated Route routes = 1; +message ListNetworksResponse { + repeated Network routes = 1; } -message SelectRoutesRequest { - repeated string routeIDs = 1; +message SelectNetworksRequest { + repeated string networkIDs = 1; bool append = 2; bool all = 3; } -message SelectRoutesResponse { +message SelectNetworksResponse { } message IPList { @@ -265,9 +265,9 @@ message IPList { } -message Route { +message Network { string ID = 1; - string network = 2; + string range = 2; bool selected = 3; repeated string domains = 4; map resolvedIPs = 5; diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 2e063604a..39424aee9 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -31,12 +31,12 @@ type DaemonServiceClient interface { Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) - // List available network routes - ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) + // List available networks + ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -115,27 +115,27 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques return out, nil } -func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) { - out := new(ListRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...) +func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + out := new(ListNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...) +func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...) +func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -222,12 +222,12 @@ type DaemonServiceServer interface { Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) - // List available network routes - ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) + // List available networks + ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -267,14 +267,14 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") } -func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented") +func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") } -func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented") +func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") } -func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented") +func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") @@ -418,56 +418,56 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } -func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ListRoutesRequest) +func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).ListRoutes(ctx, in) + return srv.(DaemonServiceServer).ListNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListRoutes", + FullMethod: "/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest)) + return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).SelectRoutes(ctx, in) + return srv.(DaemonServiceServer).SelectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectRoutes", + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, in) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectRoutes", + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } @@ -630,16 +630,16 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ Handler: _DaemonService_GetConfig_Handler, }, { - MethodName: "ListRoutes", - Handler: _DaemonService_ListRoutes_Handler, + MethodName: "ListNetworks", + Handler: _DaemonService_ListNetworks_Handler, }, { - MethodName: "SelectRoutes", - Handler: _DaemonService_SelectRoutes_Handler, + MethodName: "SelectNetworks", + Handler: _DaemonService_SelectNetworks_Handler, }, { - MethodName: "DeselectRoutes", - Handler: _DaemonService_DeselectRoutes_Handler, + MethodName: "DeselectNetworks", + Handler: _DaemonService_DeselectNetworks_Handler, }, { MethodName: "DebugBundle", diff --git a/client/server/route.go b/client/server/network.go similarity index 79% rename from client/server/route.go rename to client/server/network.go index 1033ae7d3..b4b4071b4 100644 --- a/client/server/route.go +++ b/client/server/network.go @@ -20,8 +20,8 @@ type selectRoute struct { Selected bool } -// ListRoutes returns a list of all available routes. -func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { +// ListNetworks returns a list of all available networks. +func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*proto.ListNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -67,11 +67,11 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L }) resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() - var pbRoutes []*proto.Route + var pbRoutes []*proto.Network for _, route := range routes { - pbRoute := &proto.Route{ + pbRoute := &proto.Network{ ID: string(route.NetID), - Network: route.Network.String(), + Range: route.Network.String(), Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -91,13 +91,13 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L pbRoutes = append(pbRoutes, pbRoute) } - return &proto.ListRoutesResponse{ + return &proto.ListNetworksResponse{ Routes: pbRoutes, }, nil } -// SelectRoutes selects specific routes based on the client request. -func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// SelectNetworks selects specific networks based on the client request. +func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -115,7 +115,7 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) if req.GetAll() { routeSelector.SelectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) + routes := toNetIDs(req.GetNetworkIDs()) netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) @@ -123,11 +123,11 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) } routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } -// DeselectRoutes deselects specific routes based on the client request. -func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// DeselectNetworks deselects specific networks based on the client request. +func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -145,7 +145,7 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques if req.GetAll() { routeSelector.DeselectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) + routes := toNetIDs(req.GetNetworkIDs()) netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) @@ -153,7 +153,7 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques } routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } func toNetIDs(routes []string) []route.NetID { diff --git a/client/server/server.go b/client/server/server.go index 71eb58a66..5640ffa39 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -772,7 +772,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ @@ -791,7 +791,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.GetRoutes()), + Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) diff --git a/client/server/server_test.go b/client/server/server_test.go index 61bdaf660..8df033d91 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -20,6 +20,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -110,7 +111,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d046bab5f..157775a11 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -58,7 +58,7 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") var showRoutes bool - flag.BoolVar(&showRoutes, "routes", false, "run routes windows") + flag.BoolVar(&showRoutes, "networks", false, "run networks windows") var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") @@ -233,7 +233,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo s.showSettingsUI() return s } else if showRoutes { - s.showRoutesUI() + s.showNetworksUI() } return s @@ -549,7 +549,7 @@ func (s *serviceClient) onTrayReady() { s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") s.loadSettings() - s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") + s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window") s.mRoutes.Disable() systray.AddSeparator() @@ -656,7 +656,7 @@ func (s *serviceClient) onTrayReady() { s.mRoutes.Disable() go func() { defer s.mRoutes.Enable() - s.runSelfCommand("routes", "true") + s.runSelfCommand("networks", "true") }() } if err != nil { diff --git a/client/ui/route.go b/client/ui/network.go similarity index 58% rename from client/ui/route.go rename to client/ui/network.go index 5b6b8fee0..a74c714e0 100644 --- a/client/ui/route.go +++ b/client/ui/network.go @@ -19,32 +19,32 @@ import ( ) const ( - allRoutesText = "All routes" - overlappingRoutesText = "Overlapping routes" - exitNodeRoutesText = "Exit-node routes" - allRoutes filter = "all" - overlappingRoutes filter = "overlapping" - exitNodeRoutes filter = "exit-node" - getClientFMT = "get client: %v" + allNetworksText = "All networks" + overlappingNetworksText = "Overlapping networks" + exitNodeNetworksText = "Exit-node networks" + allNetworks filter = "all" + overlappingNetworks filter = "overlapping" + exitNodeNetworks filter = "exit-node" + getClientFMT = "get client: %v" ) type filter string -func (s *serviceClient) showRoutesUI() { - s.wRoutes = s.app.NewWindow("NetBird Routes") +func (s *serviceClient) showNetworksUI() { + s.wRoutes = s.app.NewWindow("Networks") allGrid := container.New(layout.NewGridLayout(3)) - go s.updateRoutes(allGrid, allRoutes) + go s.updateNetworks(allGrid, allNetworks) overlappingGrid := container.New(layout.NewGridLayout(3)) exitNodeGrid := container.New(layout.NewGridLayout(3)) routeCheckContainer := container.NewVBox() tabs := container.NewAppTabs( - container.NewTabItem(allRoutesText, allGrid), - container.NewTabItem(overlappingRoutesText, overlappingGrid), - container.NewTabItem(exitNodeRoutesText, exitNodeGrid), + container.NewTabItem(allNetworksText, allGrid), + container.NewTabItem(overlappingNetworksText, overlappingGrid), + container.NewTabItem(exitNodeNetworksText, exitNodeGrid), ) tabs.OnSelected = func(item *container.TabItem) { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) } tabs.OnUnselected = func(item *container.TabItem) { grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) @@ -58,17 +58,17 @@ func (s *serviceClient) showRoutesUI() { buttonBox := container.NewHBox( layout.NewSpacer(), widget.NewButton("Refresh", func() { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Select all", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.selectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.selectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Deselect All", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.deselectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.deselectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), layout.NewSpacer(), ) @@ -81,36 +81,36 @@ func (s *serviceClient) showRoutesUI() { s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } -func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { +func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Objects = nil grid.Refresh() idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) - networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Range/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) grid.Add(idHeader) grid.Add(networkHeader) grid.Add(resolvedIPsHeader) - filteredRoutes, err := s.getFilteredRoutes(f) + filteredRoutes, err := s.getFilteredNetworks(f) if err != nil { return } - sortRoutesByIDs(filteredRoutes) + sortNetworksByIDs(filteredRoutes) for _, route := range filteredRoutes { r := route checkBox := widget.NewCheck(r.GetID(), func(checked bool) { - s.selectRoute(r.ID, checked) + s.selectNetwork(r.ID, checked) }) checkBox.Checked = route.Selected checkBox.Resize(fyne.NewSize(20, 20)) checkBox.Refresh() grid.Add(checkBox) - network := r.GetNetwork() + network := r.GetRange() domains := r.GetDomains() if len(domains) == 0 { @@ -151,35 +151,35 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Refresh() } -func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) { - routes, err := s.fetchRoutes() +func (s *serviceClient) getFilteredNetworks(f filter) ([]*proto.Network, error) { + routes, err := s.fetchNetworks() if err != nil { log.Errorf(getClientFMT, err) s.showError(fmt.Errorf(getClientFMT, err)) return nil, err } switch f { - case overlappingRoutes: - return getOverlappingRoutes(routes), nil - case exitNodeRoutes: - return getExitNodeRoutes(routes), nil + case overlappingNetworks: + return getOverlappingNetworks(routes), nil + case exitNodeNetworks: + return getExitNodeNetworks(routes), nil default: } return routes, nil } -func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route - existingRange := make(map[string][]*proto.Route) +func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network + existingRange := make(map[string][]*proto.Network) for _, route := range routes { if len(route.Domains) > 0 { continue } - if r, exists := existingRange[route.GetNetwork()]; exists { + if r, exists := existingRange[route.GetRange()]; exists { r = append(r, route) - existingRange[route.GetNetwork()] = r + existingRange[route.GetRange()] = r } else { - existingRange[route.GetNetwork()] = []*proto.Route{route} + existingRange[route.GetRange()] = []*proto.Network{route} } } for _, r := range existingRange { @@ -190,29 +190,29 @@ func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { return filteredRoutes } -func getExitNodeRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route +func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network for _, route := range routes { - if route.Network == "0.0.0.0/0" { + if route.Range == "0.0.0.0/0" { filteredRoutes = append(filteredRoutes, route) } } return filteredRoutes } -func sortRoutesByIDs(routes []*proto.Route) { +func sortNetworksByIDs(routes []*proto.Network) { sort.Slice(routes, func(i, j int) bool { return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID()) }) } -func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { +func (s *serviceClient) fetchNetworks() ([]*proto.Network, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { return nil, fmt.Errorf(getClientFMT, err) } - resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) + resp, err := conn.ListNetworks(s.ctx, &proto.ListNetworksRequest{}) if err != nil { return nil, fmt.Errorf("failed to list routes: %v", err) } @@ -220,7 +220,7 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { return resp.Routes, nil } -func (s *serviceClient) selectRoute(id string, checked bool) { +func (s *serviceClient) selectNetwork(id string, checked bool) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) @@ -228,73 +228,73 @@ func (s *serviceClient) selectRoute(id string, checked bool) { return } - req := &proto.SelectRoutesRequest{ - RouteIDs: []string{id}, - Append: checked, + req := &proto.SelectNetworksRequest{ + NetworkIDs: []string{id}, + Append: checked, } if checked { - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select route: %v", err) - s.showError(fmt.Errorf("failed to select route: %v", err)) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select network: %v", err) + s.showError(fmt.Errorf("failed to select network: %v", err)) return } log.Infof("Route %s selected", id) } else { - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect route: %v", err) - s.showError(fmt.Errorf("failed to deselect route: %v", err)) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect network: %v", err) + s.showError(fmt.Errorf("failed to deselect network: %v", err)) return } - log.Infof("Route %s deselected", id) + log.Infof("Network %s deselected", id) } } -func (s *serviceClient) selectAllFilteredRoutes(f filter) { +func (s *serviceClient) selectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, true) - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select all routes: %v", err) - s.showError(fmt.Errorf("failed to select all routes: %v", err)) + req := s.getNetworksRequest(f, true) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select all networks: %v", err) + s.showError(fmt.Errorf("failed to select all networks: %v", err)) return } - log.Debug("All routes selected") + log.Debug("All networks selected") } -func (s *serviceClient) deselectAllFilteredRoutes(f filter) { +func (s *serviceClient) deselectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, false) - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect all routes: %v", err) - s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) + req := s.getNetworksRequest(f, false) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect all networks: %v", err) + s.showError(fmt.Errorf("failed to deselect all networks: %v", err)) return } - log.Debug("All routes deselected") + log.Debug("All networks deselected") } -func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest { - req := &proto.SelectRoutesRequest{} - if f == allRoutes { +func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.SelectNetworksRequest { + req := &proto.SelectNetworksRequest{} + if f == allNetworks { req.All = true } else { - routes, err := s.getFilteredRoutes(f) + routes, err := s.getFilteredNetworks(f) if err != nil { return nil } for _, route := range routes { - req.RouteIDs = append(req.RouteIDs, route.GetID()) + req.NetworkIDs = append(req.NetworkIDs, route.GetID()) } req.Append = appendRoute } @@ -311,7 +311,7 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container ticker := time.NewTicker(interval) go func() { for range ticker.C { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) } }() @@ -320,20 +320,20 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container }) } -func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { +func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) s.wRoutes.Content().Refresh() - s.updateRoutes(grid, f) + s.updateNetworks(grid, f) } func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { switch tabs.Selected().Text { - case overlappingRoutesText: - return overlappingGrid, overlappingRoutes - case exitNodeRoutesText: - return exitNodesGrid, exitNodeRoutes + case overlappingNetworksText: + return overlappingGrid, overlappingNetworks + case exitNodeNetworksText: + return exitNodesGrid, exitNodeNetworks default: - return allGrid, allRoutes + return allGrid, allNetworks } } diff --git a/go.mod b/go.mod index 14f800036..53124aa69 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 + github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 3bdeb6827..257644284 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/client/client_test.go b/management/client/client_test.go index 100b3fcaa..083002442 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/client/system" @@ -57,7 +58,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 719d1a78c..2248b52d9 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,9 +42,11 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -149,7 +151,7 @@ var ( if err != nil { return err } - store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) + store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -257,7 +259,7 @@ var ( return fmt.Errorf("failed creating JWT validator: %v", err) } - httpAPIAuthCfg := httpapi.AuthCfg{ + httpAPIAuthCfg := configs.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, @@ -399,7 +401,7 @@ func notifyStop(ctx context.Context, msg string) { } } -func getInstallationID(ctx context.Context, store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 7aa11f0c9..183fc554d 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -9,7 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -32,7 +32,7 @@ var upCmd = &cobra.Command{ //nolint ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) - if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { + if err := store.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { return err } log.WithContext(ctx).Info("Migration finished successfully") diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 672b2a102..9662e1330 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v4.24.3 // source: management.proto package proto @@ -29,6 +29,7 @@ const ( RuleProtocol_TCP RuleProtocol = 2 RuleProtocol_UDP RuleProtocol = 3 RuleProtocol_ICMP RuleProtocol = 4 + RuleProtocol_CUSTOM RuleProtocol = 5 ) // Enum value maps for RuleProtocol. @@ -39,6 +40,7 @@ var ( 2: "TCP", 3: "UDP", 4: "ICMP", + 5: "CUSTOM", } RuleProtocol_value = map[string]int32{ "UNKNOWN": 0, @@ -46,6 +48,7 @@ var ( "TCP": 2, "UDP": 3, "ICMP": 4, + "CUSTOM": 5, } ) @@ -2780,6 +2783,10 @@ type RouteFirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` // IsDynamic indicates if the route is a DNS route. IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` + // Domains is a list of domains for which the rule is applicable. + Domains []string `protobuf:"bytes,7,rep,name=domains,proto3" json:"domains,omitempty"` + // CustomProtocol is a custom protocol ID. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -2856,6 +2863,20 @@ func (x *RouteFirewallRule) GetIsDynamic() bool { return false } +func (x *RouteFirewallRule) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *RouteFirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3266,7 +3287,7 @@ var file_management_proto_rawDesc = []byte{ 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, - 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, + 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xd1, 0x02, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, @@ -3283,50 +3304,55 @@ var file_management_proto_rawDesc = []byte{ 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, + 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2a, 0x4c, + 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, + 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, + 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, + 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, + 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, + 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, + 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, + 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, + 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, + 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, + 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, + 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, + 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, + 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, + 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index fe6a828b1..1e2931065 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -396,6 +396,7 @@ enum RuleProtocol { TCP = 2; UDP = 3; ICMP = 4; + CUSTOM = 5; } enum RuleDirection { @@ -459,5 +460,11 @@ message RouteFirewallRule { // IsDynamic indicates if the route is a DNS route. bool isDynamic = 6; + + // Domains is a list of domains for which the rule is applicable. + repeated string domains = 7; + + // CustomProtocol is a custom protocol ID. + uint32 customProtocol = 8; } diff --git a/management/server/account.go b/management/server/account.go index fbe6fcc1a..9e91a54b4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -19,8 +19,6 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - "github.com/hashicorp/go-multierror" - "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -29,31 +27,30 @@ import ( "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour - DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -66,56 +63,56 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) - GetAccount(ctx context.Context, accountID string) (*Account, error) - CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, - autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccount(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, + autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) - SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) - GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error - GetUserByID(ctx context.Context, id string) (*User, error) - GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetUserByID(ctx context.Context, id string) (*types.User, error) + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) - GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) - GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error - ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error @@ -129,12 +126,12 @@ type AccountManager interface { GetDNSDomain() string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -145,18 +142,20 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + GetValidatedPeers(account *types.Account) (map[string]struct{}, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) - GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error + GetNetworksManager() networks.Manager + GetUserManager() users.Manager } type DefaultAccountManager struct { - Store Store + Store store.Store // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded @@ -189,765 +188,47 @@ type DefaultAccountManager struct { integratedPeerValidator integrated_validator.IntegratedValidator metrics telemetry.AppMetrics -} -// Settings represents Account settings structure that can be modified via API and Dashboard -type Settings struct { - // PeerLoginExpirationEnabled globally enables or disables peer login expiration - PeerLoginExpirationEnabled bool - - // PeerLoginExpiration is a setting that indicates when peer login expires. - // Applies to all peers that have Peer.LoginExpirationEnabled set to true. - PeerLoginExpiration time.Duration - - // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration - PeerInactivityExpirationEnabled bool - - // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. - // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. - PeerInactivityExpiration time.Duration - - // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements - RegularUsersViewBlocked bool - - // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer - GroupsPropagationEnabled bool - - // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName - // and add it to account groups. - JWTGroupsEnabled bool - - // JWTGroupsClaimName from which we extract groups name to add it to account groups - JWTGroupsClaimName string - - // JWTAllowGroups list of groups to which users are allowed access - JWTAllowGroups []string `gorm:"serializer:json"` - - // Extra is a dictionary of Account settings - Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` -} - -// Copy copies the Settings struct -func (s *Settings) Copy() *Settings { - settings := &Settings{ - PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, - PeerLoginExpiration: s.PeerLoginExpiration, - JWTGroupsEnabled: s.JWTGroupsEnabled, - JWTGroupsClaimName: s.JWTGroupsClaimName, - GroupsPropagationEnabled: s.GroupsPropagationEnabled, - JWTAllowGroups: s.JWTAllowGroups, - RegularUsersViewBlocked: s.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: s.PeerInactivityExpiration, - } - if s.Extra != nil { - settings.Extra = s.Extra.Copy() - } - return settings -} - -// Account represents a unique account of the system -type Account struct { - // we have to name column to aid as it collides with Network.Id when work with associations - Id string `gorm:"primaryKey"` - - // User.Id it was created by - CreatedBy string - CreatedAt time.Time - Domain string `gorm:"index"` - DomainCategory string - IsDomainPrimaryAccount bool - SetupKeys map[string]*SetupKey `gorm:"-"` - SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` - Network *Network `gorm:"embedded;embeddedPrefix:network_"` - Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` - Users map[string]*User `gorm:"-"` - UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*nbgroup.Group `gorm:"-"` - GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` - Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` - Routes map[route.ID]*route.Route `gorm:"-"` - RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` - NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` - NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` - PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load settings and not whole account -type AccountSettings struct { - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load network and not whole account -type AccountNetwork struct { - Network *Network `gorm:"embedded;embeddedPrefix:network_"` -} - -// AccountDNSSettings used in gorm to only load dns settings and not whole account -type AccountDNSSettings struct { - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` -} - -type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` - IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -// getRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(lookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - groupListMap := a.getPeerGroups(peerID) - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) - } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - // currently we support only linux routing peers - if peer.Meta.GoOS != "linux" { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - -// GetRoutesByPrefixOrDomains return list of routes by account and route prefix -func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { - var routes []*route.Route - for _, r := range a.Routes { - dynamic := r.IsDynamic() - if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || - !dynamic && r.Network.String() == prefix.String() { - routes = append(routes, r) - } - } - - return routes -} - -// GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *nbgroup.Group { - return a.Groups[groupID] -} - -// GetPeerNetworkMap returns the networkmap for the given peer ID. -func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - validatedPeersMap map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) - routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) - } - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if metrics != nil { - objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", - a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) - } - } - - return nm -} - -func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { - var merr *multierror.Error - - if dnsDomain == "" { - log.WithContext(ctx).Error("no dns domain is set, returning empty zone") - return nbdns.CustomZone{} - } - - customZone := nbdns.CustomZone{ - Domain: dns.Fqdn(dnsDomain), - Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), - } - - domainSuffix := "." + dnsDomain - - var sb strings.Builder - for _, peer := range a.Peers { - if peer.DNSLabel == "" { - merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) - continue - } - - sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) - sb.WriteString(peer.DNSLabel) - sb.WriteString(domainSuffix) - - customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: defaultTTL, - RData: peer.IP.String(), - }) - - sb.Reset() - } - - go func() { - if merr != nil { - log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) - } - }() - - return customZone -} - -// GetExpiredPeers returns peers that have been expired -func (a *Account) GetExpiredPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.GetPeersWithExpiration() { - expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if expired { - peers = append(peers, peer) - } - } - - return peers -} - -// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are connected. -func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithExpiration() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - // consider only connected peers because others will require login on connecting to the management server - if peer.Status.LoginExpired || !peer.Status.Connected { - continue - } - _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetInactivePeers returns peers that have been expired by inactivity -func (a *Account) GetInactivePeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, inactivePeer := range a.GetPeersWithInactivity() { - inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) - if inactive { - peers = append(peers, inactivePeer) - } - } - return peers -} - -// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are not connected. -func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithInactivity() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - if peer.Status.LoginExpired || peer.Status.Connected { - continue - } - _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetPeers returns a list of all Account peers -func (a *Account) GetPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.Peers { - peers = append(peers, peer) - } - return peers -} - -// UpdateSettings saves new account settings -func (a *Account) UpdateSettings(update *Settings) *Account { - a.Settings = update.Copy() - return a -} - -// UpdatePeer saves new or replaces existing peer -func (a *Account) UpdatePeer(update *nbpeer.Peer) { - a.Peers[update.ID] = update -} - -// DeletePeer deletes peer from the account cleaning up all the references -func (a *Account) DeletePeer(peerID string) { - // delete peer from groups - for _, g := range a.Groups { - for i, pk := range g.Peers { - if pk == peerID { - g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) - break - } - } - } - - for _, r := range a.Routes { - if r.Peer == peerID { - r.Enabled = false - r.Peer = "" - } - } - - delete(a.Peers, peerID) - a.Network.IncSerial() -} - -// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. -// It will return an object copy of the peer. -func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { - for _, peer := range a.Peers { - if peer.Key == peerPubKey { - return peer.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) -} - -// FindUserPeers returns a list of peers that user owns (created) -func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.UserID == userID { - peers = append(peers, peer) - } - } - - return peers, nil -} - -// FindUser looks for a given user in the Account or returns error if user wasn't found. -func (a *Account) FindUser(userID string) (*User, error) { - user := a.Users[userID] - if user == nil { - return nil, status.Errorf(status.NotFound, "user %s not found", userID) - } - - return user, nil -} - -// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { - for _, group := range a.Groups { - if group.Name == groupName { - return group, nil - } - } - return nil, status.Errorf(status.NotFound, "group %s not found", groupName) -} - -// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. -func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { - key := a.SetupKeys[setupKey] - if key == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return key, nil -} - -// GetPeerGroupsList return with the list of groups ID. -func (a *Account) GetPeerGroupsList(peerID string) []string { - var grps []string - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - grps = append(grps, groupID) - break - } - } - } - return grps -} - -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.getPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (a *Account) getPeerGroups(peerID string) lookupMap { - groupList := make(lookupMap) - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - groupList[groupID] = struct{}{} - break - } - } - } - return groupList -} - -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - -func (a *Account) getPeerDNSLabels() lookupMap { - existingLabels := make(lookupMap) - for _, peer := range a.Peers { - if peer.DNSLabel != "" { - existingLabels[peer.DNSLabel] = struct{}{} - } - } - return existingLabels -} - -func (a *Account) Copy() *Account { - peers := map[string]*nbpeer.Peer{} - for id, peer := range a.Peers { - peers[id] = peer.Copy() - } - - users := map[string]*User{} - for id, user := range a.Users { - users[id] = user.Copy() - } - - setupKeys := map[string]*SetupKey{} - for id, key := range a.SetupKeys { - setupKeys[id] = key.Copy() - } - - groups := map[string]*nbgroup.Group{} - for id, group := range a.Groups { - groups[id] = group.Copy() - } - - policies := []*Policy{} - for _, policy := range a.Policies { - policies = append(policies, policy.Copy()) - } - - routes := map[route.ID]*route.Route{} - for id, r := range a.Routes { - routes[id] = r.Copy() - } - - nsGroups := map[string]*nbdns.NameServerGroup{} - for id, nsGroup := range a.NameServerGroups { - nsGroups[id] = nsGroup.Copy() - } - - dnsSettings := a.DNSSettings.Copy() - - var settings *Settings - if a.Settings != nil { - settings = a.Settings.Copy() - } - - postureChecks := []*posture.Checks{} - for _, postureCheck := range a.PostureChecks { - postureChecks = append(postureChecks, postureCheck.Copy()) - } - - return &Account{ - Id: a.Id, - CreatedBy: a.CreatedBy, - CreatedAt: a.CreatedAt, - Domain: a.Domain, - DomainCategory: a.DomainCategory, - IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, - SetupKeys: setupKeys, - Network: a.Network.Copy(), - Peers: peers, - Users: users, - Groups: groups, - Policies: policies, - Routes: routes, - NameServerGroups: nsGroups, - DNSSettings: dnsSettings, - PostureChecks: postureChecks, - Settings: settings, - } -} - -func (a *Account) GetGroupAll() (*nbgroup.Group, error) { - for _, g := range a.Groups { - if g.Name == "All" { - return g, nil - } - } - return nil, fmt.Errorf("no group ALL found") -} - -// GetPeer looks up a Peer by ID -func (a *Account) GetPeer(peerID string) *nbpeer.Peer { - return a.Peers[peerID] + networksManager networks.Manager + userManager users.Manager + settingsManager settings.Manager + permissionsManager permissions.Manager } // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - existedGroupsByName := make(map[string]*nbgroup.Group) +func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*types.Group, groupNames []string) (bool, []string, []*types.Group, error) { + existedGroupsByName := make(map[string]*types.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) - groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) + groupsToAdd := util.Difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := util.Difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return false, nil, nil, nil } - newGroupsToCreate := make([]*nbgroup.Group, 0) + newGroupsToCreate := make([]*types.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { - group = &nbgroup.Group{ + group = &types.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, - Issued: nbgroup.GroupIssuedJWT, + Issued: types.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } @@ -964,78 +245,10 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro return modified, newUserAutoGroups, newGroupsToCreate, nil } -// UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - userPeers := make(map[string]struct{}) - for pid, peer := range a.Peers { - if peer.UserID == userID { - userPeers[pid] = struct{}{} - } - } - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok { - continue - } - - oldPeers := group.Peers - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for pid := range userPeers { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupUpdates[gid] = difference(group.Peers, oldPeers) - } - - return groupUpdates -} - -// UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok || group.Name == "All" { - continue - } - - oldPeers := group.Peers - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - peer, ok := a.Peers[pid] - if !ok { - continue - } - if peer.UserID != userID { - update = append(update, pid) - } - } - group.Peers = update - groupUpdates[gid] = difference(oldPeers, group.Peers) - } - - return groupUpdates -} - // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, - store Store, + store store.Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, @@ -1046,11 +259,18 @@ func BuildManager( integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, ) (*DefaultAccountManager, error) { + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) am := &DefaultAccountManager{ Store: store, geo: geo, peersUpdateManager: peersUpdateManager, idpManager: idpManager, + networksManager: networks.NewManager(store, permissionsManager), + userManager: userManager, + settingsManager: settingsManager, + permissionsManager: permissionsManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, @@ -1137,7 +357,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1206,7 +426,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } -func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { if newSettings.GroupsPropagationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) @@ -1219,7 +439,7 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -1272,7 +492,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *types.Account) { am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextPeerExpiration(); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) @@ -1309,7 +529,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) { am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) @@ -1318,7 +538,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() @@ -1398,7 +618,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Role != UserRoleOwner { + if user.Role != types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { @@ -1436,7 +656,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) + return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. @@ -1473,13 +693,13 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } - cachedAccount := &Account{ + cachedAccount := &types.Account{ Id: accountID, - Users: make(map[string]*User), + Users: make(map[string]*types.User), } for _, user := range accountUsers { cachedAccount.Users[user.Id] = user @@ -1562,14 +782,14 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) { users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { if user.IsServiceUser { continue } - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { continue } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) @@ -1739,7 +959,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err @@ -1749,7 +969,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -1834,8 +1054,8 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*User) - usersMap[claims.UserId] = NewRegularUser(claims.UserId) + usersMap := make(map[string]*types.User) + usersMap[claims.UserId] = types.NewRegularUser(claims.UserId) err := am.Store.SaveUsers(domainAccountID, usersMap) if err != nil { return "", err @@ -1923,22 +1143,22 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string } // GetAccount returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { - if len(token) != PATLength { +func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { return nil, nil, nil, fmt.Errorf("token has wrong length") } - prefix := token[:len(PATPrefix)] - if prefix != PATPrefix { + prefix := token[:len(types.PATPrefix)] + if prefix != types.PATPrefix { return nil, nil, nil, fmt.Errorf("token has wrong prefix") } - secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] - encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] + secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] + encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { @@ -1976,8 +1196,8 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st } // GetAccountByID returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -1998,7 +1218,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = PrivateCategory + claims.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } @@ -2007,7 +1227,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) @@ -2034,7 +1254,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2060,14 +1280,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var addNewGroups []string var removeOldGroups []string var hasChanges bool - var user *User - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + var user *types.User + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2083,31 +1303,31 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } - addNewGroups = difference(updatedAutoGroups, user.AutoGroups) - removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -2117,11 +1337,11 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2139,7 +1359,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2152,7 +1372,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2210,7 +1430,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", errors.New(emptyUserID) } - if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { + if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } @@ -2248,7 +1468,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return am.addNewPrivateAccount(ctx, domainAccountID, claims) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2263,7 +1483,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests - domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2284,7 +1504,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err @@ -2295,7 +1515,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err @@ -2322,10 +1542,10 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain + return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) @@ -2422,7 +1642,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, return err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2455,8 +1675,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2477,14 +1697,14 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := getPeerHostLabel(peerHostName, labelMap) + newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap) if err != nil { return "", fmt.Errorf("failed to get new host label: %w", err) } @@ -2496,8 +1716,8 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } -func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -2506,70 +1726,78 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } - return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) +} + +func (am *DefaultAccountManager) GetNetworksManager() networks.Manager { + return am.networksManager +} + +func (am *DefaultAccountManager) GetUserManager() users.Manager { + return am.userManager } // addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { +func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() - defaultPolicy := &Policy{ + defaultPolicy := &types.Policy{ ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, Sources: []string{allGroup.ID}, Destinations: []string{allGroup.ID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } - account.Policies = []*Policy{defaultPolicy} + account.Policies = []*types.Policy{defaultPolicy} } return nil } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() + network := types.NewNetwork() peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) + users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} + setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := NewOwnerUser(userID) + owner := types.NewOwnerUser(userID) owner.AccountID = accountID users[userID] = owner - dnsSettings := DNSSettings{ + dnsSettings := types.DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ + acc := &types.Account{ Id: accountID, CreatedAt: time.Now().UTC(), SetupKeys: setupKeys, @@ -2581,14 +1809,14 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac Routes: routes, NameServerGroups: nameServersGroups, DNSSettings: dnsSettings, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, }, } @@ -2634,18 +1862,18 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID - allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + allGroupsMap := make(map[string]*types.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index 5f4897e6a..fa6c45856 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -7,6 +7,9 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // AccountRequest holds the result channel to return the requested account. @@ -17,19 +20,19 @@ type AccountRequest struct { // AccountResult holds the account data or an error. type AccountResult struct { - Account *Account + Account *types.Account Err error } type AccountRequestBuffer struct { - store Store + store store.Store getAccountRequests map[string][]*AccountRequest mu sync.Mutex getAccountRequestCh chan *AccountRequest bufferInterval time.Duration } -func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer { +func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL") bufferInterval, err := time.ParseDuration(bufferIntervalStr) if err != nil { @@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu return &ac } -func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) { +func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) { req := &AccountRequest{ AccountID: accountID, ResultChan: make(chan *AccountResult, 1), diff --git a/management/server/account_test.go b/management/server/account_test.go index d952e118a..ca8f21963 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -16,6 +16,11 @@ import ( "time" "github.com/golang-jwt/jwt" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,11 +29,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -46,7 +52,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -73,7 +79,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) func (MocIntegratedValidator) Stop(_ context.Context) { } -func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { +func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", @@ -101,7 +107,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac } } -func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) { +func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) { t.Helper() if len(account.Peers) != 0 { t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) @@ -156,7 +162,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // peerID3 := "peer-3" tt := []struct { name string - accountSettings Settings + accountSettings types.Settings peerID string expectedPeers []string expectedOfflinePeers []string @@ -164,7 +170,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }{ { name: "Should return ALL peers when global peer login expiration disabled", - accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{peerID2}, expectedOfflinePeers: []string{}, @@ -202,7 +208,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }, { name: "Should return no peers when global peer login expiration enabled and peers expired", - accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{}, expectedOfflinePeers: []string{peerID2}, @@ -396,12 +402,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { netIP := net.IP{100, 64, 0, 0} netMask := net.IPMask{255, 255, 0, 0} - network := &Network{ + network := &types.Network{ Identifier: "network", Net: net.IPNet{IP: netIP, Mask: netMask}, Dns: "netbird.selfhosted", Serial: 0, - mu: sync.Mutex{}, + Mu: sync.Mutex{}, } for _, testCase := range tt { @@ -485,12 +491,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory + initUnknown.DomainCategory = types.UnknownCategory initUnknown.Domain = unknownDomain privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory + privateInitAccount.DomainCategory = types.PrivateCategory testCases := []struct { name string @@ -500,7 +506,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputUpdateClaimAccount bool testingFunc require.ComparisonAssertionFunc expectedMSG string - expectedUserRole UserRole + expectedUserRole types.UserRole expectedDomainCategory string expectedDomain string expectedPrimaryDomainStatus bool @@ -512,12 +518,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: publicDomain, UserId: "pub-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomainCategory: "", expectedDomain: publicDomain, expectedPrimaryDomainStatus: false, @@ -529,12 +535,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: unknownDomain, UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + DomainCategory: types.UnknownCategory, }, inputInitUserParams: initUnknown, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: unknownDomain, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -546,14 +552,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: "pvt-domain-user", expectedUsers: []string{"pvt-domain-user"}, @@ -563,15 +569,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateAttrs: true, inputInitUserParams: privateInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, + expectedUserRole: types.UserRoleUser, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, @@ -581,14 +587,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -598,15 +604,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateClaimAccount: true, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -616,12 +622,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: "", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: "", expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -733,7 +739,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*group.Group{} + groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -741,32 +747,36 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match") }) } func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -786,15 +796,20 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, @@ -802,7 +817,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -904,7 +919,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { return } - exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") @@ -914,7 +929,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } } -func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { +func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) { account := newAccountWithId(context.Background(), accountID, userID, domain) err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -990,13 +1005,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { claims := jwtclaims.AuthorizationClaims{ Domain: "example.com", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, } publicClaims := jwtclaims.AuthorizationClaims{ Domain: "test.com", UserId: "public-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, } am, err := createManager(b) @@ -1074,13 +1089,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } -func genUsers(p string, n int) map[string]*User { - users := map[string]*User{} +func genUsers(p string, n int) map[string]*types.User { + users := map[string]*types.User{} now := time.Now() for i := 0; i < n; i++ { - users[fmt.Sprintf("%s-%d", p, i)] = &User{ + users[fmt.Sprintf("%s-%d", p, i)] = &types.User{ Id: fmt.Sprintf("%s-%d", p, i), - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, LastLogin: now, CreatedAt: now, Issued: "api", @@ -1105,7 +1120,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1232,7 +1247,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{}, @@ -1242,15 +1257,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1309,7 +1324,7 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -1334,15 +1349,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1357,7 +1372,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, @@ -1367,15 +1382,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1413,7 +1428,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1426,15 +1441,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1482,7 +1497,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1557,7 +1572,7 @@ func TestGetUsersFromAccount(t *testing.T) { t.Fatal(err) } - users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} + users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} accountId := "test_account_id" account, err := createAccount(manager, accountId, users["1"].Id, "") @@ -1589,7 +1604,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1636,11 +1651,11 @@ func TestAccount_GetRoutesToSync(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1681,7 +1696,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1691,26 +1706,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } func TestAccount_Copy(t *testing.T) { - account := &Account{ + account := &types.Account{ Id: "account1", CreatedBy: "tester", CreatedAt: time.Now().UTC(), Domain: "test.com", DomainCategory: "public", IsDomainPrimaryAccount: true, - SetupKeys: map[string]*SetupKey{ + SetupKeys: map[string]*types.SetupKey{ "setup1": { Id: "setup1", AutoGroups: []string{"group1"}, }, }, - Network: &Network{ + Network: &types.Network{ Identifier: "net1", }, Peers: map[string]*nbpeer.Peer{ @@ -1723,12 +1738,12 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Users: map[string]*User{ + Users: map[string]*types.User{ "user1": { Id: "user1", - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, AutoGroups: []string{"group1"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -1741,17 +1756,18 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "policy1", Enabled: true, - Rules: make([]*PolicyRule, 0), + Rules: make([]*types.PolicyRule, 0), SourcePostureChecks: make([]string, 0), }, }, @@ -1771,13 +1787,36 @@ func TestAccount_Copy(t *testing.T) { NameServers: []nbdns.NameServer{}, }, }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, PostureChecks: []*posture.Checks{ { ID: "posture Checks1", }, }, - Settings: &Settings{}, + Settings: &types.Settings{}, + Networks: []*networkTypes.Network{ + { + ID: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 0, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "resource", + Type: "Subnet", + Address: "172.12.6.1/24", + }, + }, } err := hasNilField(account) if err != nil { @@ -1830,7 +1869,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.NotNil(t, settings) @@ -1863,7 +1902,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1911,7 +1950,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1980,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1993,7 +2032,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2011,7 +2050,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2019,19 +2058,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2104,9 +2143,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -2188,9 +2227,9 @@ func TestAccount_GetInactivePeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -2255,7 +2294,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2324,7 +2363,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2488,9 +2527,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextPeerExpiration() @@ -2648,9 +2687,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextInactivePeerExpiration() @@ -2669,7 +2708,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { require.NoError(t, err, "unable to create account manager") // create a new account - account := &Account{ + account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, @@ -2678,11 +2717,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, - Users: map[string]*User{ + Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, + Users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "accountID"}, "user2": {Id: "user2", AccountID: "accountID"}, }, @@ -2698,7 +2737,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) @@ -2711,18 +2750,18 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} - assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) + assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2731,13 +2770,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { @@ -2748,7 +2787,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2761,7 +2800,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2774,11 +2813,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") }) @@ -2791,7 +2830,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") @@ -2799,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { } func TestAccount_UserGroupsAddToPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2807,12 +2846,12 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("add groups", func(t *testing.T) { @@ -2835,7 +2874,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { } func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2843,12 +2882,12 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("remove groups", func(t *testing.T) { @@ -2891,10 +2930,10 @@ func createManager(t TB) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t TB) (Store, error) { +func createStore(t TB) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -2917,7 +2956,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() manager, err := createManager(t) @@ -2930,12 +2969,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") } - getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer { key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) diff --git a/management/server/config.go b/management/server/config.go index 2f7e49766..f3555b92b 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -156,7 +157,7 @@ type ProviderConfig struct { // StoreConfig contains Store configuration type StoreConfig struct { - Engine StoreEngine + Engine store.Engine } // ReverseProxy contains reverse proxy configuration in front of management. diff --git a/management/server/dns.go b/management/server/dns.go index 8df211b0b..27c27dd47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "slices" - "strconv" "sync" log "github.com/sirupsen/logrus" @@ -12,12 +10,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const defaultTTL = 300 - // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { CustomZones sync.Map @@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG c.NameServerGroups.Store(key, value) } -type lookupMap map[string]struct{} - -// DNSSettings defines dns settings at the account level -type DNSSettings struct { - // DisabledManagementGroups groups whose DNS management is disabled - DisabledManagementGroups []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the DNS settings -func (d DNSSettings) Copy() DNSSettings { - settings := DNSSettings{ - DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), - } - copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) - return settings -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { return err } - oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) if err != nil { @@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -162,11 +143,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } // prepareDNSSettingsEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) return nil @@ -203,7 +184,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t } // areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. -func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) if err != nil { return false, err @@ -217,12 +198,12 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc } // validateDNSSettings validates the DNS settings. -func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { +func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error { if len(settings.DisabledManagementGroups) == 0 { return nil } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) if err != nil { return err } @@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } - -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.getPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - -func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { - for _, peer := range account.Peers { - label, err := getPeerHostLabel(peer.Name, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) - label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) - continue - } - } - peer.DNSLabel = label - peerLabels[label] = struct{}{} - } -} - -func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) { - label, err := nbdns.GetParsedDomainLabel(name) - if err != nil { - return "", err - } - - uniqueLabel := getUniqueHostLabel(label, peerLabels) - if uniqueLabel == "" { - return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) - } - return uniqueLabel, nil -} - -// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 -func getUniqueHostLabel(name string, peerLabels lookupMap) string { - _, found := peerLabels[name] - if !found { - return name - } - for i := 1; i < 1000; i++ { - nameWithSuffix := name + "-" + strconv.Itoa(i) - _, found = peerLabels[nameWithSuffix] - if !found { - return nameWithSuffix - } - } - return "" -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..6fb9f6a29 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -53,7 +54,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + account.DNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{group1ID}, } @@ -86,20 +87,20 @@ func TestSaveDNSSettings(t *testing.T) { testCases := []struct { name string userID string - inputSettings *DNSSettings + inputSettings *types.DNSSettings shouldFail bool }{ { name: "Saving As Admin Should Be OK", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, }, { name: "Should Not Update Settings As Regular User", userID: dnsRegularUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, shouldFail: true, @@ -113,7 +114,7 @@ func TestSaveDNSSettings(t *testing.T) { { name: "Should Not Update Settings If Group Is Invalid", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{"non-existing-group"}, }, shouldFail: true, @@ -210,10 +211,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createDNSStore(t *testing.T) (Store, error) { +func createDNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -222,7 +223,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -259,9 +260,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - account.Users[dnsRegularUserID] = &User{ + account.Users[dnsRegularUserID] = &types.User{ Id: dnsRegularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } err := am.Store.SaveAccount(context.Background(), account) @@ -293,13 +294,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &group.Group{ + newGroup1 := &types.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &group.Group{ + newGroup2 := &types.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } @@ -483,7 +484,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -510,7 +511,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -550,7 +551,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -589,7 +590,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA", "groupB"}, }) assert.NoError(t, err) @@ -609,7 +610,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -629,7 +630,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{}, }) assert.NoError(t, err) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 590b1d708..3c629a0db 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -9,6 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -21,7 +23,7 @@ var ( type ephemeralPeer struct { id string - account *Account + account *types.Account deadline time.Time next *ephemeralPeer } @@ -32,7 +34,7 @@ type ephemeralPeer struct { // EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store Store + store store.Store accountManager AccountManager headPeer *ephemeralPeer @@ -42,7 +44,7 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { +func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager { return &EphemeralManager{ store: store, accountManager: accountManager, @@ -177,7 +179,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) { ep := &ephemeralPeer{ id: id, account: account, diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5..ac8372440 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -8,18 +8,20 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) type MockStore struct { - Store - account *Account + store.Store + account *types.Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} +func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{s.account} } -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { +func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) { _, ok := s.account.Peers[peerId] if ok { return s.account, nil diff --git a/management/server/group.go b/management/server/group.go index 7b307cf1a..cd228af65 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -10,10 +10,12 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -28,7 +30,7 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -45,38 +47,38 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco } // GetGroup returns a specific group by groupID in an account -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { + return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -90,10 +92,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } var eventsToStore []func() - var groupsToSave []*nbgroup.Group + var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -113,11 +115,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -135,16 +137,16 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { - addedPeers = difference(newGroup.Peers, oldGroup.Peers) - removedPeers = difference(oldGroup.Peers, newGroup.Peers) + addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) + removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { @@ -153,7 +155,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac } modifiedPeers := slices.Concat(addedPeers, removedPeers) - peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) if err != nil { log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) return nil @@ -194,21 +196,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return eventsToStore } -// difference returns the elements in `a` that aren't in `b`. -func difference(a, b []string) []string { - mb := make(map[string]struct{}, len(b)) - for _, x := range b { - mb[x] = struct{}{} - } - var diff []string - for _, x := range a { - if _, found := mb[x]; !found { - diff = append(diff, x) - } - } - return diff -} - // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -223,7 +210,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -238,11 +225,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var allErrors error var groupIDsToDelete []string - var deletedGroups []*nbgroup.Group + var deletedGroups []*types.Group - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID) if err != nil { allErrors = errors.Join(allErrors, err) continue @@ -257,11 +244,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -279,12 +266,12 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -298,11 +285,52 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupAddResource appends resource to the group +func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err @@ -320,12 +348,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -339,11 +367,52 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupDeleteResource removes resource from the group +func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemoveResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err @@ -357,13 +426,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } // validateNewGroup validates the new group for existence and required fields. -func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { +func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *types.Group) error { + if newGroup.ID == "" && newGroup.Issued != types.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err @@ -380,7 +449,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, } for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } @@ -389,14 +458,14 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, return nil } -func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if group.Issued == types.GroupIssuedIntegration { + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } - if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } @@ -429,8 +498,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. } // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. -func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -439,7 +508,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -452,8 +521,8 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -469,8 +538,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -487,8 +556,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -506,8 +575,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -522,8 +591,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { - users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -538,12 +607,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } @@ -566,7 +635,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -576,8 +645,8 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s } // anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) +func anyGroupHasPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) if err != nil { return false, err } diff --git a/management/server/group_test.go b/management/server/group_test.go index ec017fc57..834388d1e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -32,22 +32,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedIntegration + group.Issued = types.GroupIssuedIntegration err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedJWT + group.Issued = types.GroupIssuedJWT err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) + t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI group.ID = "" err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { @@ -145,13 +145,13 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { manager, account, err := initTestGroupAccount(am) assert.NoError(t, err, "Failed to init testing account") - groups := make([]*nbgroup.Group, 10) + groups := make([]*types.Group, 10) for i := 0; i < 10; i++ { - groups[i] = &nbgroup.Group{ + groups[i] = &types.Group{ ID: fmt.Sprintf("group-%d", i+1), AccountID: account.Id, Name: fmt.Sprintf("group-%d", i+1), - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } } @@ -267,63 +267,63 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } -func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) { +func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &nbgroup.Group{ + groupForRoute := &types.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForRoute2 := &nbgroup.Group{ + groupForRoute2 := &types.Group{ ID: "grp-for-route2", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &nbgroup.Group{ + groupForNameServerGroups := &types.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &nbgroup.Group{ + groupForPolicies := &types.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &nbgroup.Group{ + groupForSetupKeys := &types.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &nbgroup.Group{ + groupForUsers := &types.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &nbgroup.Group{ + groupForIntegration := &types.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: nbgroup.GroupIssuedIntegration, + Issued: types.GroupIssuedIntegration, Peers: make([]string, 0), } @@ -342,9 +342,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A Groups: []string{groupForNameServerGroups.ID}, } - policy := &Policy{ + policy := &types.Policy{ ID: "example policy", - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "example policy rule", Destinations: []string{groupForPolicies.ID}, @@ -352,12 +352,12 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A }, } - setupKey := &SetupKey{ + setupKey := &types.SetupKey{ Id: "example setup key", AutoGroups: []string{groupForSetupKeys.ID}, } - user := &User{ + user := &types.User{ Id: "example user", AutoGroups: []string{groupForUsers.ID}, } @@ -392,7 +392,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -429,7 +429,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, @@ -500,15 +500,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -522,7 +522,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -591,7 +591,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, @@ -632,7 +632,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to dns settings should update account peers and send peer update t.Run("saving group linked to dns settings", func(t *testing.T) { - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupD"}, }) assert.NoError(t, err) @@ -659,7 +659,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 9c12336f8..b5c782d0d 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/posture" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) // GRPCServer an instance of a Management gRPC API server @@ -599,7 +600,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok } } -func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) return &proto.PeerConfig{ @@ -609,7 +610,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { response := &proto.SyncResponse{ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), @@ -661,7 +662,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { var err error var turnToken *Token diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2e084f6e4..6a1088141 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -668,6 +668,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + resources_count: + description: Count of resources associated to the group + type: integer + example: 5 issued: description: How the group was issued (api, integration, jwt) type: string @@ -677,6 +681,7 @@ components: - id - name - peers_count + - resources_count GroupRequest: type: object properties: @@ -690,6 +695,10 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m1" + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - name Group: @@ -702,8 +711,13 @@ components: type: array items: $ref: '#/components/schemas/PeerMinimum' + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - peers + - resources PolicyRuleMinimum: type: object properties: @@ -782,15 +796,18 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: type: string example: "ch8i4ug6lnn4g9h7v7m0" - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyRule: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -801,14 +818,17 @@ components: type: array items: $ref: '#/components/schemas/GroupMinimum' + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: $ref: '#/components/schemas/GroupMinimum' - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyMinimum: type: object properties: @@ -1176,6 +1196,139 @@ components: - id - network_type - $ref: '#/components/schemas/RouteRequest' + Resource: + type: object + properties: + id: + description: ID of the resource + type: string + example: chacdk86lnnboviihd7g + type: + description: Type of the resource + $ref: '#/components/schemas/ResourceType' + required: + - id + - type + ResourceType: + allOf: + - $ref: '#/components/schemas/NetworkResourceType' + - type: string + example: host + NetworkRequest: + type: object + properties: + name: + description: Network name + type: string + example: Remote Network 1 + description: + description: Network description + type: string + example: A remote network that needs to be accessed + required: + - name + Network: + allOf: + - type: object + properties: + id: + description: Network ID + type: string + example: chacdk86lnnboviihd7g + routers: + description: List of router IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + resources: + description: List of network resource IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m1 + required: + - id + - routers + - resources + - $ref: '#/components/schemas/NetworkRequest' + NetworkResourceRequest: + type: object + properties: + name: + description: Network resource name + type: string + example: Remote Resource 1 + description: + description: Network resource description + type: string + example: A remote resource inside network 1 + address: + description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + type: string + example: "1.1.1.1" + required: + - name + - address + NetworkResource: + allOf: + - type: object + properties: + id: + description: Network Resource ID + type: string + example: chacdk86lnnboviihd7g + type: + $ref: '#/components/schemas/NetworkResourceType' + required: + - id + - type + - $ref: '#/components/schemas/NetworkResourceRequest' + NetworkResourceType: + description: Network resource type based of the address + type: string + enum: [ "host", "subnet", "domain" ] + example: host + NetworkRouterRequest: + type: object + properties: + peer: + description: Peer Identifier associated with route. This property can not be set together with `peer_groups` + type: string + example: chacbco6lnnbn6cg5s91 + peer_groups: + description: Peers Group Identifier associated with route. This property can not be set together with `peer` + type: array + items: + type: string + example: chacbco6lnnbn6cg5s91 + metric: + description: Route metric number. Lowest number has higher priority + type: integer + maximum: 9999 + minimum: 1 + example: 9999 + masquerade: + description: Indicate if peer should masquerade traffic to this route's prefix + type: boolean + example: true + required: + # Only one property has to be set + #- peer + #- peer_groups + - metric + - masquerade + NetworkRouter: + allOf: + - type: object + properties: + id: + description: Network Router Id + type: string + example: chacdk86lnnboviihd7g + required: + - id + - $ref: '#/components/schemas/NetworkRouterRequest' Nameserver: type: object properties: @@ -2460,6 +2613,502 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/networks: + get: + summary: List all Networks + description: Returns a list of all networks + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Networks + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network + description: Creates a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New Network request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network Object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}: + get: + summary: Retrieve a Network + description: Get information about a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network + description: Update/Replace a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: Update Network request + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network + description: Delete a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources: + get: + summary: List all Network Resources + description: Returns a list of all resources in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Resources + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Resource + description: Creates a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources/{resourceId}: + get: + summary: Retrieve a Network Resource + description: Get information about a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Resource + description: Update a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a resource + requestBody: + description: Update Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Resource + description: Delete a network resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers: + get: + summary: List all Network Routers + description: Returns a list of all routers in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Routers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Router + description: Creates a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers/{routerId}: + get: + summary: Retrieve a Network Router + description: Get information about a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Router + description: Update a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + requestBody: + description: Update Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Router + description: Delete a network router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/dns/nameservers: get: summary: List all Nameserver Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 321395d25..0ffc6eabe 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -88,6 +88,13 @@ const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Defines values for NetworkResourceType. +const ( + NetworkResourceTypeDomain NetworkResourceType = "domain" + NetworkResourceTypeHost NetworkResourceType = "host" + NetworkResourceTypeSubnet NetworkResourceType = "subnet" +) + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" @@ -136,6 +143,13 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Defines values for ResourceType. +const ( + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -365,7 +379,11 @@ type Group struct { Peers []PeerMinimum `json:"peers"` // PeersCount Count of peers associated to the group - PeersCount int `json:"peers_count"` + PeersCount int `json:"peers_count"` + Resources []Resource `json:"resources"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupIssued How the group was issued (api, integration, jwt) @@ -384,6 +402,9 @@ type GroupMinimum struct { // PeersCount Count of peers associated to the group PeersCount int `json:"peers_count"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupMinimumIssued How the group was issued (api, integration, jwt) @@ -395,7 +416,8 @@ type GroupRequest struct { Name string `json:"name"` // Peers List of peers ids - Peers *[]string `json:"peers,omitempty"` + Peers *[]string `json:"peers,omitempty"` + Resources *[]Resource `json:"resources,omitempty"` } // Location Describe geographical location information @@ -494,6 +516,99 @@ type NameserverGroupRequest struct { SearchDomainsEnabled bool `json:"search_domains_enabled"` } +// Network defines model for Network. +type Network struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Id Network ID + Id string `json:"id"` + + // Name Network name + Name string `json:"name"` + + // Resources List of network resource IDs associated with the network + Resources []string `json:"resources"` + + // Routers List of router IDs associated with the network + Routers []string `json:"routers"` +} + +// NetworkRequest defines model for NetworkRequest. +type NetworkRequest struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Name Network name + Name string `json:"name"` +} + +// NetworkResource defines model for NetworkResource. +type NetworkResource struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Id Network Resource ID + Id string `json:"id"` + + // Name Network resource name + Name string `json:"name"` + + // Type Network resource type based of the address + Type NetworkResourceType `json:"type"` +} + +// NetworkResourceRequest defines model for NetworkResourceRequest. +type NetworkResourceRequest struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceType Network resource type based of the address +type NetworkResourceType string + +// NetworkRouter defines model for NetworkRouter. +type NetworkRouter struct { + // Id Network Router Id + Id string `json:"id"` + + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + +// NetworkRouterRequest defines model for NetworkRouterRequest. +type NetworkRouterRequest struct { + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -779,10 +894,11 @@ type PolicyRule struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []GroupMinimum `json:"destinations"` + Destinations *[]GroupMinimum `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -800,10 +916,11 @@ type PolicyRule struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleProtocol `json:"protocol"` + Protocol PolicyRuleProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []GroupMinimum `json:"sources"` + Sources *[]GroupMinimum `json:"sources,omitempty"` } // PolicyRuleAction Policy rule accept or drops packets @@ -857,10 +974,11 @@ type PolicyRuleUpdate struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []string `json:"destinations"` + Destinations *[]string `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -878,10 +996,11 @@ type PolicyRuleUpdate struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleUpdateProtocol `json:"protocol"` + Protocol PolicyRuleUpdateProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []string `json:"sources"` + Sources *[]string `json:"sources,omitempty"` } // PolicyRuleUpdateAction Policy rule accept or drops packets @@ -955,6 +1074,16 @@ type ProcessCheck struct { Processes []Process `json:"processes"` } +// Resource defines model for Resource. +type Resource struct { + // Id ID of the resource + Id string `json:"id"` + Type ResourceType `json:"type"` +} + +// ResourceType defines model for ResourceType. +type ResourceType string + // Route defines model for Route. type Route struct { // AccessControlGroups Access control group identifier associated with route. @@ -1292,6 +1421,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest // PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType. type PutApiGroupsGroupIdJSONRequestBody = GroupRequest +// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType. +type PostApiNetworksJSONRequestBody = NetworkRequest + +// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType. +type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest + +// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType. +type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest + +// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType. +type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest + +// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType. +type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest + +// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType. +type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest + // PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType. type PutApiPeersPeerIdJSONRequestBody = PeerRequest diff --git a/management/server/http/configs/auth.go b/management/server/http/configs/auth.go new file mode 100644 index 000000000..aa91fa55b --- /dev/null +++ b/management/server/http/configs/auth.go @@ -0,0 +1,9 @@ +package configs + +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff6..1bc11b1e9 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,17 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/handlers/accounts" + "github.com/netbirdio/netbird/management/server/http/handlers/dns" + "github.com/netbirdio/netbird/management/server/http/handlers/events" + "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/networks" + "github.com/netbirdio/netbird/management/server/http/handlers/peers" + "github.com/netbirdio/netbird/management/server/http/handlers/policies" + "github.com/netbirdio/netbird/management/server/http/handlers/routes" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -20,27 +31,15 @@ import ( const apiPrefix = "/api" -// AuthCfg contains parameters for authentication middleware -type AuthCfg struct { - Issuer string - Audience string - UserIDClaim string - KeysLocation string -} - type apiHandler struct { Router *mux.Router AccountManager s.AccountManager geolocationManager *geolocation.Geolocation - AuthCfg AuthCfg -} - -// EmptyObject is an empty struct used to return empty JSON object -type emptyObject struct { + AuthCfg configs.AuthCfg } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -86,122 +85,16 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa return nil, fmt.Errorf("register integrations endpoints: %w", err) } - api.addAccountsEndpoint() - api.addPeersEndpoint() - api.addUsersEndpoint() - api.addUsersTokensEndpoint() - api.addSetupKeysEndpoint() - api.addPoliciesEndpoint() - api.addGroupsEndpoint() - api.addRoutesEndpoint() - api.addDNSNameserversEndpoint() - api.addDNSSettingEndpoint() - api.addEventsEndpoint() - api.addPostureCheckEndpoint() - api.addLocationsEndpoint() + accounts.AddEndpoints(api.AccountManager, authCfg, router) + peers.AddEndpoints(api.AccountManager, authCfg, router) + users.AddEndpoints(api.AccountManager, authCfg, router) + setup_keys.AddEndpoints(api.AccountManager, authCfg, router) + policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) + groups.AddEndpoints(api.AccountManager, authCfg, router) + routes.AddEndpoints(api.AccountManager, authCfg, router) + dns.AddEndpoints(api.AccountManager, authCfg, router) + events.AddEndpoints(api.AccountManager, authCfg, router) + networks.AddEndpoints(api.AccountManager.GetNetworksManager(), api.AccountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } - -func (apiHandler *apiHandler) addAccountsEndpoint() { - accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPeersEndpoint() { - peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). - Methods("GET", "PUT", "DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersEndpoint() { - userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersTokensEndpoint() { - tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addSetupKeysEndpoint() { - keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addPoliciesEndpoint() { - policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addGroupsEndpoint() { - groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addRoutesEndpoint() { - routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSNameserversEndpoint() { - nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSSettingEndpoint() { - dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") -} - -func (apiHandler *apiHandler) addEventsEndpoint() { - eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPostureCheckEndpoint() { - postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addLocationsEndpoint() { - locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS") -} diff --git a/management/server/http/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go similarity index 76% rename from management/server/http/accounts_handler.go rename to management/server/http/handlers/accounts/accounts_handler.go index 4baf9c692..64b993952 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "encoding/json" @@ -10,20 +10,29 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// AccountsHandler is a handler that handles the server.Account HTTP endpoints -type AccountsHandler struct { +// handler is a handler that handles the server.Account HTTP endpoints +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewAccountsHandler creates a new AccountsHandler HTTP handler -func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { - return &AccountsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + accountsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") +} + +// newHandler creates a new handler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +41,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { +// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -51,8 +60,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } -// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { +// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -74,7 +83,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) return } - settings := &server.Settings{ + settings := &types.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, @@ -111,8 +120,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteAccount is a HTTP DELETE handler to delete an account -func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { +// deleteAccount is a HTTP DELETE handler to delete an account +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -127,10 +136,10 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *server.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go similarity index 91% rename from management/server/http/accounts_handler_test.go rename to management/server/http/handlers/accounts/accounts_handler_test.go index cacb3d430..96f0755cf 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "bytes" @@ -13,23 +13,23 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { - return &AccountsHandler{ +func initAccountsTestData(account *types.Account, admin *types.User) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil }, - GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") sr := func(v string) *string { return &v } br := func(v bool) *bool { return &v } - handler := initAccountsTestData(&server.Account{ + handler := initAccountsTestData(&types.Account{ Id: accountID, Domain: "hotmail.com", - Network: server.NewNetwork(), - Users: map[string]*server.User{ + Network: types.NewNetwork(), + Users: map[string]*types.User{ adminUser.Id: adminUser, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, @@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllAccounts OK", + name: "getAllAccounts OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/accounts", @@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go similarity index 59% rename from management/server/http/dns_settings_handler.go rename to management/server/http/handlers/dns/dns_settings_handler.go index 13c2101a7..112eee179 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -1,26 +1,40 @@ -package http +package dns import ( "encoding/json" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) -// DNSSettingsHandler is a handler that returns the DNS settings of the account -type DNSSettingsHandler struct { +// dnsSettingsHandler is a handler that returns the DNS settings of the account +type dnsSettingsHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler -func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { - return &DNSSettingsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + addDNSSettingEndpoint(accountManager, authCfg, router) + addDNSNameserversEndpoint(accountManager, authCfg, router) +} + +func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) + router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") +} + +// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler +func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -29,8 +43,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetDNSSettings returns the DNS settings for the account -func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { +// getDNSSettings returns the DNS settings for the account +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -52,8 +66,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque util.WriteJSONObject(r.Context(), w, apiDNSSettings) } -// UpdateDNSSettings handles update to DNS settings of an account -func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { +// updateDNSSettings handles update to DNS settings of an account +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -68,7 +82,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re return } - updateDNSSettings := &server.DNSSettings{ + updateDNSSettings := &types.DNSSettings{ DisabledManagementGroups: req.DisabledManagementGroups, } diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go similarity index 86% rename from management/server/http/dns_settings_handler_test.go rename to management/server/http/handlers/dns/dns_settings_handler_test.go index 8baea7b15..9ca1dc032 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -13,10 +13,10 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -27,26 +27,26 @@ const ( testDNSSettingsUserID = "test_user" ) -var baseExistingDNSSettings = server.DNSSettings{ +var baseExistingDNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, } -var testingDNSSettingsAccount = &server.Account{ +var testingDNSSettingsAccount = &types.Account{ Id: testDNSSettingsAccountID, Domain: "hotmail.com", - Users: map[string]*server.User{ - testDNSSettingsUserID: server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + testDNSSettingsUserID: types.NewAdminUser("test_user"), }, DNSSettings: baseExistingDNSSettings, } -func initDNSSettingsTestData() *DNSSettingsHandler { - return &DNSSettingsHandler{ +func initDNSSettingsTestData() *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave != nil { return nil } @@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go similarity index 77% rename from management/server/http/nameservers_handler.go rename to management/server/http/handlers/dns/nameservers_handler.go index e7a2bc2ae..09047e231 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -1,4 +1,4 @@ -package http +package dns import ( "encoding/json" @@ -11,20 +11,30 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// NameserversHandler is the nameserver group handler of the account -type NameserversHandler struct { +// nameserversHandler is the nameserver group handler of the account +type nameserversHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewNameserversHandler returns a new instance of NameserversHandler handler -func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { - return &NameserversHandler{ +func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager, authCfg) + router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") +} + +// newNameserversHandler returns a new instance of nameserversHandler handler +func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { + return &nameserversHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetAllNameservers returns the list of nameserver groups for the account -func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { +// getAllNameservers returns the list of nameserver groups for the account +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, apiNameservers) } -// CreateNameserverGroup handles nameserver group creation request -func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// createNameserverGroup handles nameserver group creation request +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// UpdateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// updateNameserverGroup handles update to a nameserver group identified by a given ID +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteNameserverGroup handles nameserver group deletion request -func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { +// deleteNameserverGroup handles nameserver group deletion request +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetNameserverGroup handles a nameserver group Get request identified by ID -func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { +// getNameserverGroup handles a nameserver group Get request identified by ID +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go similarity index 95% rename from management/server/http/nameservers_handler_test.go rename to management/server/http/handlers/dns/nameservers_handler_test.go index 98c2e402d..c6561e4d8 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ Enabled: true, } -func initNameserversTestData() *NameserversHandler { - return &NameserversHandler{ +func initNameserversTestData() *nameserversHandler { + return &nameserversHandler{ accountManager: &mock_server.MockAccountManager{ GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { @@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") + router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/events_handler.go b/management/server/http/handlers/events/events_handler.go similarity index 79% rename from management/server/http/events_handler.go rename to management/server/http/handlers/events/events_handler.go index ee0c63f28..62da59535 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,28 +1,35 @@ -package http +package events import ( "context" "fmt" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// EventsHandler HTTP handler -type EventsHandler struct { +// handler HTTP handler +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewEventsHandler creates a new EventsHandler HTTP handler -func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { - return &EventsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + eventsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") +} + +// newHandler creates a new events handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev } } -// GetAllEvents list of the given account -func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { +// getAllEvents list of the given account +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { +func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { diff --git a/management/server/http/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go similarity index 95% rename from management/server/http/events_handler_test.go rename to management/server/http/handlers/events/events_handler_test.go index e525cf2ee..17478aba3 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -1,4 +1,4 @@ -package http +package events import ( "context" @@ -13,15 +13,15 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" ) -func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { - return &EventsHandler{ +func initEventsTestData(account string, events ...*activity.Event) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { @@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *EventsHandle GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - return make([]*server.UserInfo, 0), nil + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + return make([]*types.UserInfo, 0), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllEvents OK", + name: "getAllEvents OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/events/", @@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) { }, } accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) handler := initEventsTestData(accountID, events...) @@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") + router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go similarity index 69% rename from management/server/http/groups_handler.go rename to management/server/http/handlers/groups/groups_handler.go index f369d1a00..ee52d8b4c 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -1,30 +1,41 @@ -package http +package groups import ( "encoding/json" "net/http" "github.com/gorilla/mux" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/http/configs" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// GroupsHandler is a handler that returns groups of the account -type GroupsHandler struct { +// handler is a handler that returns groups of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGroupsHandler creates a new GroupsHandler HTTP handler -func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { - return &GroupsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + groupsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new groups handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr } } -// GetAllGroups list for the account -func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, groupsResponse) } -// UpdateGroup handles update to a group identified by a given ID -func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { +// updateGroup handles update to a group identified by a given ID +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -118,10 +129,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ ID: groupID, Name: req.Name, Peers: peers, + Resources: resources, Issued: existingGroup.Issued, IntegrationReference: existingGroup.IntegrationReference, } @@ -141,8 +163,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// CreateGroup handles group creation request -func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { +// createGroup handles group creation request +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,10 +190,21 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ - Name: req.Name, - Peers: peers, - Issued: nbgroup.GroupIssuedAPI, + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ + Name: req.Name, + Peers: peers, + Resources: resources, + Issued: types.GroupIssuedAPI, } err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) @@ -189,8 +222,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// DeleteGroup handles group deletion request -func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { +// deleteGroup handles group deletion request +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -215,11 +248,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetGroup returns a group -func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { +// getGroup returns a group +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -248,13 +281,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } -func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { peersMap := make(map[string]*nbpeer.Peer, len(peers)) for _, peer := range peers { peersMap[peer.ID] = peer } - cache := make(map[string]api.PeerMinimum) + resMap := make(map[string]types.Resource, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + peerCache := make(map[string]api.PeerMinimum) + resCache := make(map[string]api.Resource) gr := api.Group{ Id: group.ID, Name: group.Name, @@ -262,7 +301,7 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { } for _, pid := range group.Peers { - _, ok := cache[pid] + _, ok := peerCache[pid] if !ok { peer, ok := peersMap[pid] if !ok { @@ -272,12 +311,27 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { Id: peer.ID, Name: peer.Name, } - cache[pid] = peerResp + peerCache[pid] = peerResp gr.Peers = append(gr.Peers, peerResp) } } gr.PeersCount = len(gr.Peers) + for _, res := range group.Resources { + _, ok := resCache[res.ID] + if !ok { + resource, ok := resMap[res.ID] + if !ok { + continue + } + resResp := resource.ToAPIResponse() + resCache[res.ID] = *resResp + gr.Resources = append(gr.Resources, *resResp) + } + } + + gr.ResourcesCount = len(gr.Resources) + return &gr } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go similarity index 90% rename from management/server/http/groups_handler_test.go rename to management/server/http/handlers/groups/groups_handler_test.go index 7f3c81f18..49805ca9b 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -1,4 +1,4 @@ -package http +package groups import ( "bytes" @@ -17,13 +17,13 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -31,20 +31,20 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { - return &GroupsHandler{ +func initGroupTestData(initGroups ...*types.Group) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - groups := map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*types.Group, error) { + groups := map[string]*types.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, } for _, group := range initGroups { @@ -61,9 +61,9 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { - return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, fmt.Errorf("unknown group name") @@ -106,21 +106,21 @@ func TestGetGroup(t *testing.T) { requestBody io.Reader }{ { - name: "GetGroup OK", + name: "getGroup OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/groups/idofthegroup", expectedStatus: http.StatusOK, }, { - name: "GetGroup not found", + name: "getGroup not found", requestType: http.MethodGet, requestPath: "/api/groups/notexists", expectedStatus: http.StatusNotFound, }, } - group := &nbgroup.Group{ + group := &types.Group{ ID: "idofthegroup", Name: "Group", } @@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -154,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &nbgroup.Group{} + got := &types.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } @@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") + router.HandleFunc("/api/groups", p.createGroup).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go new file mode 100644 index 000000000..1ce856118 --- /dev/null +++ b/management/server/http/handlers/networks/handler.go @@ -0,0 +1,229 @@ +package networks + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/status" +) + +// handler is a handler that returns networks of the account +type handler struct { + networksManager networks.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func AddEndpoints(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + networksHandler := newHandler(networksManager, extractFromToken, authCfg) + router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") + addRouterEndpoints(networksManager.GetRouterManager(), extractFromToken, authCfg, router) + addResourceEndpoints(networksManager.GetResourceManager(), extractFromToken, authCfg, router) +} + +func newHandler(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { + return &handler{ + networksManager: networksManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routers, err := h.networksManager.GetRouterManager().GetAllRouterIDsInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resources, err := h.networksManager.GetResourceManager().GetAllResourceIDsInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var networkResponse []*api.Network + for _, network := range networks { + networkResponse = append(networkResponse, network.ToAPIResponse(routers[network.ID], resources[network.ID])) + } + + util.WriteJSONObject(r.Context(), w, networkResponse) +} + +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.AccountID = accountID + network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{})) +} + +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs)) +} + +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.ID = networkID + network.AccountID = accountID + network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs)) +} + +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, error) { + resources, err := h.networksManager.GetResourceManager().GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get resources in network: %w", err) + } + + var resourceIDs []string + for _, resource := range resources { + resourceIDs = append(resourceIDs, resource.ID) + } + + routers, err := h.networksManager.GetRouterManager().GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get routers in network: %w", err) + } + + var routerIDs []string + for _, router := range routers { + routerIDs = append(routerIDs, router.ID) + } + + return routerIDs, resourceIDs, nil +} diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go new file mode 100644 index 000000000..9221cefaf --- /dev/null +++ b/management/server/http/handlers/networks/resources_handler.go @@ -0,0 +1,184 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) + +type resourceHandler struct { + resourceManager resources.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addResourceEndpoints(resourcesManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, extractFromToken, authCfg) + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") +} + +func newResourceHandler(resourceManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { + return &resourceHandler{ + resourceManager: resourceManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} + +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse()) +} + +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse()) +} + +func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.ID = mux.Vars(r)["resourceId"] + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse()) +} + +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go new file mode 100644 index 000000000..2cf39a132 --- /dev/null +++ b/management/server/http/handlers/networks/routers_handler.go @@ -0,0 +1,165 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/routers/types" +) + +type routersHandler struct { + routersManager routers.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) + router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") +} + +func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler { + return &routersHandler{ + routersManager: routersManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var routersResponse []*api.NetworkRouter + for _, router := range routers { + routersResponse = append(routersResponse, router.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, routersResponse) +} + +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = networkID + router.AccountID = accountID + + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = mux.Vars(r)["networkId"] + router.ID = mux.Vars(r)["routerId"] + router.AccountID = accountID + + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, struct{}{}) +} diff --git a/management/server/http/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go similarity index 86% rename from management/server/http/peers_handler.go rename to management/server/http/handlers/peers/peers_handler.go index f5027cd77..4562766bd 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,4 +1,4 @@ -package http +package peers import ( "context" @@ -10,23 +10,32 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// PeersHandler is a handler that returns peers of the account -type PeersHandler struct { +// Handler is a handler that returns peers of the account +type Handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPeersHandler creates a new PeersHandler HTTP handler -func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { - return &PeersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + peersHandler := NewHandler(accountManager, authCfg) + router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). + Methods("GET", "PUT", "DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") +} + +// NewHandler creates a new peers Handler +func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { + return &Handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } -func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { peerToReturn := peer.Copy() if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected @@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } -func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { +func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to delete peer: %v", err) util.WriteError(ctx, err, w) return } - util.WriteJSONObject(ctx, w, emptyObject{}) + util.WriteJSONObject(ctx, w, util.EmptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { } // GetAllPeers returns a list of all peers associated with a provided account -func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -190,7 +199,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - groupsMap := map[string]*nbgroup.Group{} + groupsMap := map[string]*types.Group{} groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) for _, group := range groups { groupsMap[group.ID] = group @@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] if !ok { @@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -286,7 +295,7 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { +func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) @@ -315,7 +324,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*types.Group, peerID string) []api.GroupMinimum { groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go similarity index 94% rename from management/server/http/peers_handler_test.go rename to management/server/http/handlers/peers/peers_handler_test.go index dd49c03b8..83abc1c40 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -1,4 +1,4 @@ -package http +package peers import ( "bytes" @@ -15,11 +15,10 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/assert" @@ -38,8 +37,8 @@ const ( userIDKey ctxKey = "user_id" ) -func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { - return &PeersHandler{ +func initTestMetaData(peers ...*nbpeer.Peer) *Handler { + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer @@ -73,18 +72,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() } - policy := &server.Policy{ + policy := &types.Policy{ ID: "policy", AccountID: accountID, Name: "policy", Enabled: true, - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "rule", Name: "rule", @@ -99,19 +98,19 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - srvUser := server.NewRegularUser(serviceUser) + srvUser := types.NewRegularUser(serviceUser) srvUser.IsServiceUser = true - account := &server.Account{ + account := &types.Account{ Id: accountID, Domain: "hotmail.com", Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), + Users: map[string]*types.User{ + adminUser: types.NewAdminUser(adminUser), + regularUser: types.NewRegularUser(regularUser), serviceUser: srvUser, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "group1": { ID: "group1", AccountID: accountID, @@ -120,12 +119,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { Peers: maps.Keys(peersMap), }, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ + Policies: []*types.Policy{policy}, + Network: &types.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ IP: net.ParseIP("100.67.0.0"), diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go similarity index 91% rename from management/server/http/geolocation_handler_test.go rename to management/server/http/handlers/policies/geolocation_handler_test.go index 19c916dd2..fc5839baa 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "context" @@ -11,22 +11,22 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) -func initGeolocationTestData(t *testing.T) *GeolocationsHandler { +func initGeolocationTestData(t *testing.T) *geolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" - geonamesdbPath = "../testdata/geonames_20240305.db" + mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../../../testdata/geonames_20240305.db" ) tempDir := t.TempDir() @@ -41,13 +41,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) - return &GeolocationsHandler{ + return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { - return server.NewAdminUser(id), nil + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { + return types.NewAdminUser(id), nil }, }, geolocationManager: geo, @@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go similarity index 72% rename from management/server/http/geolocations_handler.go rename to management/server/http/handlers/policies/geolocations_handler.go index 418228abf..e5bf3e695 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "net/http" @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -18,16 +19,22 @@ var ( countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") ) -// GeolocationsHandler is a handler that returns locations. -type GeolocationsHandler struct { +// geolocationsHandler is a handler that returns locations. +type geolocationsHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGeolocationsHandlerHandler creates a new Geolocations handler -func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { - return &GeolocationsHandler{ +func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") +} + +// newGeolocationsHandlerHandler creates a new Geolocations handler +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { + return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca } } -// GetAllCountries retrieves a list of all countries -func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { +// getAllCountries retrieves a list of all countries +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req util.WriteJSONObject(r.Context(), w, countries) } -// GetCitiesByCountry retrieves a list of cities based on the given country code -func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { +// getCitiesByCountry retrieves a list of cities based on the given country code +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { +func (l *geolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go similarity index 66% rename from management/server/http/policies_handler.go rename to management/server/http/handlers/policies/policies_handler.go index eff9092d4..d538d07db 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -6,23 +6,36 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// Policies is a handler that returns policy of the account -type Policies struct { +// handler is a handler that returns policy of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPoliciesHandler creates a new Policies handler -func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { - return &Policies{ +func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + policiesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) +} + +// newHandler creates a new policies handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllPolicies list for the account -func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { +// getAllPolicies list for the account +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -65,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, policies) } -// UpdatePolicy handles update to a policy identified by a given ID -func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { +// updatePolicy handles update to a policy identified by a given ID +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { h.savePolicy(w, r, accountID, userID, policyID) } -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { +// createPolicy handles policy creation request +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -103,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { } // savePolicy handles policy creation and update -func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -120,7 +133,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - policy := &server.Policy{ + policy := &types.Policy{ ID: policyID, AccountID: accountID, Name: req.Name, @@ -133,15 +146,56 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID ruleID = *rule.Id } - pr := server.PolicyRule{ + hasSources := rule.Sources != nil + hasSourceResource := rule.SourceResource != nil + + hasDestinations := rule.Destinations != nil + hasDestinationResource := rule.DestinationResource != nil + + if hasSources && hasSourceResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources, not both"), w) + return + } + + if hasDestinations && hasDestinationResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either destinations or destination resources, not both"), w) + return + } + + if !(hasSources || hasSourceResource) || !(hasDestinations || hasDestinationResource) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources and destinations or destination resources"), w) + return + } + + pr := types.PolicyRule{ ID: ruleID, PolicyID: policyID, Name: rule.Name, - Destinations: rule.Destinations, - Sources: rule.Sources, Bidirectional: rule.Bidirectional, } + if hasSources { + pr.Sources = *rule.Sources + } + + if hasSourceResource { + // TODO: validate the resource id and type + sourceResource := &types.Resource{} + sourceResource.FromAPIRequest(rule.SourceResource) + pr.SourceResource = *sourceResource + } + + if hasDestinations { + pr.Destinations = *rule.Destinations + } + + if hasDestinationResource { + // TODO: validate the resource id and type + destinationResource := &types.Resource{} + destinationResource.FromAPIRequest(rule.DestinationResource) + pr.DestinationResource = *destinationResource + } + pr.Enabled = rule.Enabled if rule.Description != nil { pr.Description = *rule.Description @@ -149,9 +203,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID switch rule.Action { case api.PolicyRuleUpdateActionAccept: - pr.Action = server.PolicyTrafficActionAccept + pr.Action = types.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: - pr.Action = server.PolicyTrafficActionDrop + pr.Action = types.PolicyTrafficActionDrop default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return @@ -159,13 +213,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: - pr.Protocol = server.PolicyRuleProtocolALL + pr.Protocol = types.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: - pr.Protocol = server.PolicyRuleProtocolTCP + pr.Protocol = types.PolicyRuleProtocolTCP case api.PolicyRuleUpdateProtocolUdp: - pr.Protocol = server.PolicyRuleProtocolUDP + pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: - pr.Protocol = server.PolicyRuleProtocolICMP + pr.Protocol = types.PolicyRuleProtocolICMP default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -192,7 +246,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } - pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + pr.PortRanges = append(pr.PortRanges, types.RulePortRange{ Start: uint16(portRange.Start), End: uint16(portRange.End), }) @@ -201,7 +255,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID // validate policy object switch pr.Protocol { - case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return @@ -210,7 +264,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } - case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP: if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return @@ -245,8 +299,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteJSONObject(r.Context(), w, resp) } -// DeletePolicy handles policy deletion request -func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { +// deletePolicy handles policy deletion request +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -266,11 +320,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetPolicy handles a group Get request identified by ID -func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { +// getPolicy handles a group Get request identified by ID +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -306,8 +360,8 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { - groupsMap := make(map[string]*nbgroup.Group) +func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -324,13 +378,15 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rID := r.ID rDescription := r.Description rule := api.PolicyRule{ - Id: &rID, - Name: r.Name, - Enabled: r.Enabled, - Description: &rDescription, - Bidirectional: r.Bidirectional, - Protocol: api.PolicyRuleProtocol(r.Protocol), - Action: api.PolicyRuleAction(r.Action), + Id: &rID, + Name: r.Name, + Enabled: r.Enabled, + Description: &rDescription, + Bidirectional: r.Bidirectional, + Protocol: api.PolicyRuleProtocol(r.Protocol), + Action: api.PolicyRuleAction(r.Action), + SourceResource: r.SourceResource.ToAPIResponse(), + DestinationResource: r.DestinationResource.ToAPIResponse(), } if len(r.Ports) != 0 { @@ -349,26 +405,30 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.PortRanges = &portRanges } + var sources []api.GroupMinimum for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, PeersCount: len(group.Peers), } - rule.Sources = append(rule.Sources, minimum) + sources = append(sources, minimum) cache[gid] = minimum } } + rule.Sources = &sources + var destinations []api.GroupMinimum for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { - rule.Destinations = append(rule.Destinations, cachedMinimum) + destinations = append(destinations, cachedMinimum) continue } if group, ok := groupsMap[gid]; ok { @@ -377,10 +437,12 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic Name: group.Name, PeersCount: len(group.Peers), } - rule.Destinations = append(rule.Destinations, minimum) + destinations = append(destinations, minimum) cache[gid] = minimum } } + rule.Destinations = &destinations + ap.Rules = append(ap.Rules, rule) } return ap diff --git a/management/server/http/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go similarity index 83% rename from management/server/http/policies_handler_test.go rename to management/server/http/handlers/policies/policies_handler_test.go index f8a897eb2..956d0b7cd 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -10,9 +10,9 @@ import ( "strings" "testing" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" @@ -20,50 +20,49 @@ import ( "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *Policies { - testPolicies := make(map[string]*server.Policy, len(policies)) +func initPoliciesTestData(policies ...*types.Policy) *handler { + testPolicies := make(map[string]*types.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } - return &Policies{ + return &handler{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return policy, nil }, - GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { - return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - user := server.NewAdminUser(userID) - return &server.Account{ + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user := types.NewAdminUser(userID) + return &types.Account{ Id: accountID, Domain: "hotmail.com", - Policies: []*server.Policy{ + Policies: []*types.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "test_user": user, }, }, nil @@ -91,24 +90,24 @@ func TestPoliciesGetPolicy(t *testing.T) { requestBody io.Reader }{ { - name: "GetPolicy OK", + name: "getPolicy OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/policies/idofthepolicy", expectedStatus: http.StatusOK, }, { - name: "GetPolicy not found", + name: "getPolicy not found", requestType: http.MethodGet, requestPath: "/api/policies/notexists", expectedStatus: http.StatusNotFound, }, } - policy := &server.Policy{ + policy := &types.Policy{ ID: "idofthepolicy", Name: "Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ {ID: "idoftherule", Name: "Rule"}, }, } @@ -121,7 +120,7 @@ func TestPoliciesGetPolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -177,7 +176,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["G"] } ]}`)), expectedStatus: http.StatusOK, @@ -193,6 +194,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "G"}}, }, }, }, @@ -221,7 +224,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["F"] } ]}`)), expectedStatus: http.StatusOK, @@ -237,6 +242,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "F"}}, }, }, }, @@ -251,10 +258,10 @@ func TestPoliciesWritePolicy(t *testing.T) { }, } - p := initPoliciesTestData(&server.Policy{ + p := initPoliciesTestData(&types.Policy{ ID: "id-existed", Name: "Default POSTed Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "id-existed", Name: "Default POSTed Rule", @@ -269,8 +276,8 @@ func TestPoliciesWritePolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go similarity index 70% rename from management/server/http/posture_checks_handler.go rename to management/server/http/handlers/policies/posture_checks_handler.go index 2c8204292..44917605b 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -9,22 +9,33 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PostureChecksHandler is a handler that returns posture checks of the account. -type PostureChecksHandler struct { +// postureChecksHandler is a handler that returns posture checks of the account. +type postureChecksHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPostureChecksHandler creates a new PostureChecks handler -func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { - return &PostureChecksHandler{ +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + addLocationsEndpoint(accountManager, locationManager, authCfg, router) +} + +// newPostureChecksHandler creates a new PostureChecks handler +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { + return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa } } -// GetAllPostureChecks list for the account -func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { +// getAllPostureChecks list for the account +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, postureChecks) } -// UpdatePostureCheck handles update to a posture check identified by a given ID -func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { +// updatePostureCheck handles update to a posture check identified by a given ID +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, postureChecksID) } -// CreatePostureCheck handles posture check creation request -func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { +// createPostureCheck handles posture check creation request +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, "") } -// GetPostureCheck handles a posture check Get request identified by ID -func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { +// getPostureCheck handles a posture check Get request identified by ID +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } -// DeletePostureCheck handles posture check deletion request -func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { +// deletePostureCheck handles posture check deletion request +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go similarity index 96% rename from management/server/http/posture_checks_handler_test.go rename to management/server/http/handlers/policies/posture_checks_handler_test.go index f400cec81..e9a539e45 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -25,13 +25,13 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" -func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { +func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { testPostureChecks[postureCheck.ID] = postureCheck } - return &PostureChecksHandler{ + return &postureChecksHandler{ accountManager: &mock_server.MockAccountManager{ GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] @@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) { requestBody io.Reader }{ { - name: "GetPostureCheck NBVersion OK", + name: "getPostureCheck NBVersion OK", expectedBody: true, id: postureCheck.ID, checkName: postureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck OSVersion OK", + name: "getPostureCheck OSVersion OK", expectedBody: true, id: osPostureCheck.ID, checkName: osPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck GeoLocation OK", + name: "getPostureCheck GeoLocation OK", expectedBody: true, id: geoPostureCheck.ID, checkName: geoPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck PrivateNetwork OK", + name: "getPostureCheck PrivateNetwork OK", expectedBody: true, id: privateNetworkCheck.ID, checkName: privateNetworkCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck Not Found", + name: "getPostureCheck Not Found", id: "not-exists", expectedStatus: http.StatusNotFound, }, @@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) { requestType string requestPath string requestBody io.Reader - setupHandlerFunc func(handler *PostureChecksHandler) + setupHandlerFunc func(handler *postureChecksHandler) }{ { name: "Create Posture Checks NB version", @@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go similarity index 85% rename from management/server/http/routes_handler.go rename to management/server/http/handlers/routes/routes_handler.go index cbf5e72dd..a29ba4562 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -1,4 +1,4 @@ -package http +package routes import ( "encoding/json" @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +24,24 @@ import ( const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" -// RoutesHandler is the routes handler of the account -type RoutesHandler struct { +// handler is the routes handler of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewRoutesHandler returns a new instance of RoutesHandler handler -func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { - return &RoutesHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + routesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") +} + +// newHandler returns a new instance of routes handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro } } -// GetAllRoutes returns the list of routes for the account -func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { +// getAllRoutes returns the list of routes for the account +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRoutes) } -// CreateRoute handles route creation request -func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { +// createRoute handles route creation request +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { +func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } @@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro return nil } -// UpdateRoute handles update to a route identified by a given ID -func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { +// updateRoute handles update to a route identified by a given ID +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -// DeleteRoute handles route deletion request -func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { +// deleteRoute handles route deletion request +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetRoute handles a route Get request identified by ID -func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { +// getRoute handles a route Get request identified by ID +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go similarity index 97% rename from management/server/http/routes_handler_test.go rename to management/server/http/handlers/routes/routes_handler_test.go index c6eabe782..4cee3ee30 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -1,4 +1,4 @@ -package http +package routes import ( "bytes" @@ -16,13 +16,13 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/gorilla/mux" "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -61,7 +61,7 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &server.Account{ +var testingAccount = &types.Account{ Id: testAccountID, Domain: "hotmail.com", Peers: map[string]*nbpeer.Peer{ @@ -82,13 +82,13 @@ var testingAccount = &server.Account{ }, }, }, - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + "test_user": types.NewAdminUser("test_user"), }, } -func initRoutesTestData() *RoutesHandler { - return &RoutesHandler{ +func initRoutesTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { @@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - //return testingAccount, testingAccount.Users["test_user"], nil + // return testingAccount, testingAccount.Users["test_user"], nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, @@ -529,10 +529,10 @@ func TestRoutesHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") + router.HandleFunc("/api/routes", p.createRoute).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go similarity index 73% rename from management/server/http/setupkeys_handler.go rename to management/server/http/handlers/setup_keys/setupkeys_handler.go index 9ba5977bb..89696a165 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "context" @@ -10,20 +10,31 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// SetupKeysHandler is a handler that returns a list of setup keys of the account -type SetupKeysHandler struct { +// handler is a handler that returns a list of setup keys of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler -func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler { - return &SetupKeysHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + keysHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new setup key handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +43,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) } } -// CreateSetupKey is a POST requests that creates a new SetupKey -func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { +// createSetupKey is a POST requests that creates a new SetupKey +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -53,8 +64,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request return } - if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || - server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { + if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable || + types.SetupKeyType(req.Type) == types.SetupKeyOneOff) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -75,7 +86,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) @@ -89,8 +100,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -// GetSetupKey is a GET request to get a SetupKey by ID -func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { +// getSetupKey is a GET request to get a SetupKey by ID +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -114,8 +125,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { writeSuccess(r.Context(), w, key) } -// UpdateSetupKey is a PUT request to update server.SetupKey -func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { +// updateSetupKey is a PUT request to update server.SetupKey +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -142,7 +153,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request return } - newKey := &server.SetupKey{} + newKey := &types.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked newKey.Id = keyID @@ -155,8 +166,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request writeSuccess(r.Context(), w, newKey) } -// GetAllSetupKeys is a GET request that returns a list of SetupKey -func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { +// getAllSetupKeys is a GET request that returns a list of SetupKey +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -178,7 +189,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -199,10 +210,10 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) @@ -212,7 +223,7 @@ func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupK } } -func toResponseBody(key *server.SetupKey) *api.SetupKey { +func toResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go similarity index 86% rename from management/server/http/setupkeys_handler_test.go rename to management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 09256d0ea..4ecb1e9ed 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "bytes" @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -26,19 +26,20 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" + testAccountID = "test_id" ) -func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, - user *server.User, -) *SetupKeysHandler { - return &SetupKeysHandler{ +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, + user *types.User, +) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, - ) (*server.SetupKey, error) { + ) (*types.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { nk := newKey.Copy() nk.Ephemeral = ephemeral @@ -46,7 +47,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -57,15 +58,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { - return []*server.SetupKey{defaultKey}, nil + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) { + return []*types.SetupKey{defaultKey}, nil }, DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { @@ -88,13 +89,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey, _ := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := types.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") - newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage, true) + newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"}, + types.SetupKeyUnlimitedUsage, true) newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} @@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/pat_handler.go b/management/server/http/handlers/users/pat_handler.go similarity index 71% rename from management/server/http/pat_handler.go rename to management/server/http/handlers/users/pat_handler.go index dfa9563e3..197785b34 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,20 +9,30 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -// PATHandler is the nameserver group handler of the account -type PATHandler struct { +// patHandler is the nameserver group handler of the account +type patHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPATsHandler creates a new PATHandler HTTP handler -func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { - return &PATHandler{ +func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager, authCfg) + router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") +} + +// newPATsHandler creates a new patHandler HTTP handler +func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { + return &patHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +41,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } -// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { +// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -61,8 +71,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, patResponse) } -// GetToken is HTTP GET handler that returns a personal access token for the given user -func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { +// getToken is HTTP GET handler that returns a personal access token for the given user +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -92,8 +102,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } -// CreateToken is HTTP POST handler that creates a personal access token for the given user -func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { +// createToken is HTTP POST handler that creates a personal access token for the given user +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -124,8 +134,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } -// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { +// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -152,10 +162,10 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { +func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { var lastUsed *time.Time if !pat.LastUsed.IsZero() { lastUsed = &pat.LastUsed @@ -170,7 +180,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { } } -func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { +func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { return &api.PersonalAccessTokenGenerated{ PlainToken: pat.PlainToken, PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), diff --git a/management/server/http/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go similarity index 87% rename from management/server/http/pat_handler_test.go rename to management/server/http/handlers/users/pat_handler_test.go index c28228a50..21bdc461e 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -31,13 +31,13 @@ const ( testDomain = "hotmail.com" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ existingTokenID: { ID: existingTokenID, Name: "My first token", @@ -61,19 +61,19 @@ var testAccount = &server.Account{ }, } -func initPATTestData() *PATHandler { - return &PATHandler{ +func initPATTestData() *patHandler { + return &patHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return &server.PersonalAccessTokenGenerated{ + return &types.PersonalAccessTokenGenerated{ PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", - PersonalAccessToken: server.PersonalAccessToken{}, + PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, @@ -92,7 +92,7 @@ func initPATTestData() *PATHandler { } return nil }, - GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -104,14 +104,14 @@ func initPATTestData() *PATHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -217,7 +217,7 @@ func TestTokenHandlers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } assert.NotEmpty(t, got.PlainToken) - assert.Equal(t, server.PATLength, len(got.PlainToken)) + assert.Equal(t, types.PATLength, len(got.PlainToken)) case "Get All Tokens": expectedTokens := []api.PersonalAccessToken{ toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), @@ -243,7 +243,7 @@ func TestTokenHandlers(t *testing.T) { } } -func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { +func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken { return api.PersonalAccessToken{ Id: serverToken.ID, Name: serverToken.Name, diff --git a/management/server/http/users_handler.go b/management/server/http/handlers/users/users_handler.go similarity index 76% rename from management/server/http/users_handler.go rename to management/server/http/handlers/users/users_handler.go index 6e151a0da..7380dd97e 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,22 +9,34 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// UsersHandler is a handler that returns users of the account -type UsersHandler struct { +// handler is a handler that returns users of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewUsersHandler creates a new UsersHandler HTTP handler -func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { - return &UsersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + userHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + addUsersTokensEndpoint(accountManager, authCfg, router) +} + +// newHandler creates a new UsersHandler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +45,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use } } -// UpdateUser is a PUT requests to update User data -func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { +// updateUser is a PUT requests to update User data +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -72,13 +84,13 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - userRole := server.StrRoleToUserRole(req.Role) - if userRole == server.UserRoleUnknown { + userRole := types.StrRoleToUserRole(req.Role) + if userRole == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -94,8 +106,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// DeleteUser is a DELETE request to delete a user -func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { +// deleteUser is a DELETE request to delete a user +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -121,11 +133,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { +// createUser creates a User in the system with a status "invited" (effectively this is a user invite). +func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -145,7 +157,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { return } - if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { + if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -160,13 +172,13 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ Email: email, Name: name, Role: req.Role, AutoGroups: req.AutoGroups, IsServiceUser: req.IsServiceUser, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }) if err != nil { util.WriteError(r.Context(), err, w) @@ -175,9 +187,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// GetAllUsers returns a list of users of the account this user belongs to. +// getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -222,9 +234,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, users) } -// InviteUser resend invitations to users who haven't activated their accounts, +// inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -250,10 +262,10 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { +func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { autoGroups = []string{} diff --git a/management/server/http/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go similarity index 92% rename from management/server/http/users_handler_test.go rename to management/server/http/handlers/users/users_handler_test.go index f3d989da1..90081830a 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -26,54 +26,54 @@ const ( regularUserID = "regularUserID" ) -var usersTestAccount = &server.Account{ +var usersTestAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nonDeletableServiceUserID: { Id: serviceUserID, Role: "admin", IsServiceUser: true, NonDeletable: true, - Issued: server.UserIssuedIntegration, + Issued: types.UserIssuedIntegration, }, }, } -func initUsersTestData() *UsersHandler { - return &UsersHandler{ +func initUsersTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - users := make([]*server.UserInfo, 0) + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + users := make([]*types.UserInfo, 0) for _, v := range usersTestAccount.Users { - users = append(users, &server.UserInfo{ + users = append(users, &types.UserInfo{ ID: v.Id, Role: string(v.Role), Name: "", @@ -85,7 +85,7 @@ func initUsersTestData() *UsersHandler { } return users, nil }, - CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } @@ -100,7 +100,7 @@ func initUsersTestData() *UsersHandler { } return nil }, - SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -109,7 +109,7 @@ func initUsersTestData() *UsersHandler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) { requestPath string expectedUserIDs []string }{ - {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, + {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, } @@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - userHandler.GetAllUsers(recorder, req) + userHandler.getAllUsers(recorder, req) res := recorder.Result() defer res.Body.Close() @@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) { return } - respBody := []*server.UserInfo{} + respBody := []*types.UserInfo{} err = json.Unmarshal(content, &respBody) if err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) @@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) { requestType string requestPath string requestBody io.Reader - expectedResult []*server.User + expectedResult []*types.User }{ {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)}, // right now creation is blocked in AC middleware, will be refactored in the future @@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - userHandler.CreateUser(rr, req) + userHandler.createUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.InviteUser(rr, req) + userHandler.inviteUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.DeleteUser(rr, req) + userHandler.deleteUser(rr, req) res := rr.Result() defer res.Body.Close() diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 0ad250f43..c5bdf5fe7 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,16 +7,16 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index b25aad99c..0d3459712 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,16 +11,16 @@ import ( "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index fdfb0ea24..b0d970c5d 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,9 +10,9 @@ import ( "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -28,13 +28,13 @@ const ( wrongToken = "wrongToken" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: accountID, Domain: domain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ userID: { Id: userID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ tokenID: { ID: tokenID, Name: "My first token", @@ -49,7 +49,7 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if token == PAT { return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 603c1c696..3d7eed498 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -14,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// EmptyObject is an empty struct used to return empty JSON object +type EmptyObject struct { +} + type ErrorResponse struct { Message string `json:"message"` Code int `json:"code"` diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 0c70b702a..47c4ca6ae 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. @@ -57,9 +59,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) if err != nil { return err } @@ -73,6 +75,6 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 03be9d039..22b8026aa 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -4,8 +4,8 @@ import ( "context" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" ) // IntegratedValidator interface exists to avoid the circle dependencies @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index dc8765e19..f1d6de361 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -23,7 +23,9 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -413,7 +415,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -618,7 +620,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -705,7 +707,7 @@ func Test_LoginPerformance(t *testing.T) { return } - setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) + setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) if err != nil { t.Logf("error creating setup key: %v", err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 5361da53f..f0f83a237 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -23,9 +23,10 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -457,7 +458,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for p := range peers { validatedPeers[p] = struct{}{} @@ -532,7 +533,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 843fa575e..82b34393f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -15,7 +15,8 @@ import ( "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -47,8 +48,8 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts(ctx context.Context) []*server.Account - GetStoreEngine() server.StoreEngine + GetAllAccounts(ctx context.Context) []*types.Account + GetStoreEngine() store.Engine } // ConnManager peer connection manager that holds state for current active connections diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 2ac2d68a0..1d356387f 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,10 +5,10 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -22,19 +22,19 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { - return []*server.Account{ +func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{ { Id: "1", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -49,20 +49,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, SourcePostureChecks: []string{"1"}, @@ -94,16 +94,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -111,15 +111,15 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, { Id: "2", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -134,20 +134,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, @@ -158,16 +158,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { PeerGroups: make([]string, 1), }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -177,8 +177,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { } // GetStoreEngine returns FileStoreEngine -func (mockDatasource) GetStoreEngine() server.StoreEngine { - return server.FileStoreEngine +func (mockDatasource) GetStoreEngine() store.Engine { + return store.FileStoreEngine } // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties @@ -267,7 +267,7 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } - if properties["store_engine"] != server.FileStoreEngine { + if properties["store_engine"] != store.FileStoreEngine { t.Errorf("expected JsonFile, got %s", properties["store_engine"]) } diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 51358c7ad..a645ae325 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -12,9 +12,9 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") type network struct { - server.Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } type account struct { - server.Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` } - err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error + err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert Gob data") var gobStr string - err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error + err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error assert.NoError(t, err, "Failed to fetch Gob data") err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated") } func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") - err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error + err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged") } @@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") type location struct { @@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { } type account struct { - server.Account + types.Account Peers []peer `gorm:"foreignKey:AccountID;references:id"` } err = db.Save(&account{ - Account: server.Account{Id: "123"}, + Account: types.Account{Id: "123"}, Peers: []peer{ {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, @@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.Account{ + err = db.Save(&types.Account{ Id: "1234", PeersG: []nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, @@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", KeySecret: "EEFDA****", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing. func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46a4fbc1f..45d5eceb6 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,47 +13,49 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -62,35 +64,35 @@ type MockAccountManager struct { SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -105,12 +107,22 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error } +func (am *MockAccountManager) GetUserManager() users.Manager { + // TODO implement me + panic("implement me") +} + +func (am *MockAccountManager) GetNetworksManager() networks.Manager { + // TODO implement me + panic("implement me") +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) @@ -118,7 +130,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } @@ -130,7 +142,7 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -139,7 +151,7 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(ctx, accountId, groupID, userID) } @@ -147,7 +159,7 @@ func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(ctx, accountID, userID) } @@ -155,7 +167,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { return am.GetUsersFromAccountFunc(ctx, accountID, userID) } @@ -173,7 +185,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { +) (*types.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } @@ -188,13 +200,13 @@ func (am *MockAccountManager) CreateSetupKey( ctx context.Context, accountID string, keyName string, - keyType server.SetupKeyType, + keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, -) (*server.SetupKey, error) { +) (*types.SetupKey, error) { if am.CreateSetupKeyFunc != nil { return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } @@ -221,7 +233,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -229,7 +241,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(ctx, pat) } @@ -253,7 +265,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } @@ -269,7 +281,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if am.GetPATFunc != nil { return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } @@ -277,7 +289,7 @@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } @@ -285,7 +297,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) { if am.GetNetworkMapFunc != nil { return am.GetNetworkMapFunc(ctx, peerKey) } @@ -293,7 +305,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) { if am.GetPeerNetworkFunc != nil { return am.GetPeerNetworkFunc(ctx, peerKey) } @@ -306,7 +318,7 @@ func (am *MockAccountManager) AddPeer( setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { return am.AddPeerFunc(ctx, setupKey, userId, peer) } @@ -314,7 +326,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(ctx, accountID, groupName) } @@ -322,7 +334,7 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(ctx, accountID, userID, group) } @@ -330,7 +342,7 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { if am.SaveGroupsFunc != nil { return am.SaveGroupsFunc(ctx, accountID, userID, groups) } @@ -378,7 +390,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { if am.GetPolicyFunc != nil { return am.GetPolicyFunc(ctx, accountID, policyID, userID) } @@ -386,7 +398,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { if am.SavePolicyFunc != nil { return am.SavePolicyFunc(ctx, accountID, userID, policy) } @@ -402,7 +414,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { if am.ListPoliciesFunc != nil { return am.ListPoliciesFunc(ctx, accountID, userID) } @@ -418,14 +430,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { if am.GetUserFunc != nil { return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { if am.ListUsersFunc != nil { return am.ListUsersFunc(ctx, accountID) } @@ -481,7 +493,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) { if am.SaveSetupKeyFunc != nil { return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } @@ -490,7 +502,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { if am.GetSetupKeyFunc != nil { return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } @@ -499,7 +511,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { if am.ListSetupKeysFunc != nil { return am.ListSetupKeysFunc(ctx, accountID, userID) } @@ -508,7 +520,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) { if am.SaveUserFunc != nil { return am.SaveUserFunc(ctx, accountID, userID, user) } @@ -516,7 +528,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) { if am.SaveOrAddUserFunc != nil { return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } @@ -524,7 +536,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user } // SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if am.SaveOrAddUsersFunc != nil { return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists) } @@ -595,7 +607,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { if am.CreateUserFunc != nil { return am.CreateUserFunc(ctx, accountID, userID, invite) } @@ -642,7 +654,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { return am.GetDNSSettingsFunc(ctx, accountID, userID) } @@ -650,7 +662,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } @@ -666,7 +678,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } @@ -674,7 +686,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { return am.LoginPeerFunc(ctx, login) } @@ -682,7 +694,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, account) } @@ -803,7 +815,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } // GetAccountByID mocks GetAccountByID of the AccountManager interface -func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { if am.GetAccountByIDFunc != nil { return am.GetAccountByIDFunc(ctx, accountID, userID) } @@ -811,21 +823,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri } // GetUserByID mocks GetUserByID of the AccountManager interface -func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { return am.GetUserByIDFunc(ctx, id) } return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") } -func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { if am.GetAccountSettingsFunc != nil { return am.GetAccountSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") } -func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { if am.GetAccountFunc != nil { return am.GetAccountFunc(ctx, accountID) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e7a5387a1..19acdf1ba 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,15 +11,16 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.NewAdminPermissionError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -64,7 +65,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } @@ -74,11 +75,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) }) if err != nil { return nil, err @@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -113,8 +114,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) }) if err != nil { return err @@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -165,8 +166,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco var nsGroup *nbdns.NameServerGroup var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) if err != nil { return err } @@ -176,11 +177,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) }) if err != nil { return err @@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) } -func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { +func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err @@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } @@ -243,7 +244,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false, nil } @@ -305,7 +306,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups map[string]*types.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..0743db513 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,9 +11,10 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -772,10 +773,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createNSStore(t *testing.T) (Store, error) { +func createNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -784,7 +785,7 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, @@ -842,12 +843,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &nbgroup.Group{ + newGroup1 := &types.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &nbgroup.Group{ + newGroup2 := &types.Group{ ID: group2ID, Name: group2ID, } @@ -944,7 +945,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go new file mode 100644 index 000000000..61dd59cb8 --- /dev/null +++ b/management/server/networks/manager.go @@ -0,0 +1,110 @@ +package networks + +import ( + "context" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) + CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) + UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error + GetResourceManager() resources.Manager + GetRouterManager() routers.Manager +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + routersManager routers.Manager + resourcesManager resources.Manager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + routersManager: routers.NewManager(store, permissionsManager), + resourcesManager: resources.NewManager(store, permissionsManager), + } +} + +func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + network.ID = xid.New().String() + + return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) +} + +func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) +} + +func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) +} + +func (m *managerImpl) GetResourceManager() resources.Manager { + return m.resourcesManager +} + +func (m *managerImpl) GetRouterManager() routers.Manager { + return m.routersManager +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go new file mode 100644 index 000000000..ad62f7b03 --- /dev/null +++ b/management/server/networks/resources/manager.go @@ -0,0 +1,149 @@ +package resources + +import ( + "context" + "errors" + "fmt" + + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) + GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) + GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) + CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) + UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network resources: %w", err) + } + + resourceMap := make(map[string][]string) + for _, resource := range resources { + resourceMap[resource.NetworkID] = append(resourceMap[resource.NetworkID], resource.ID) + } + + return resourceMap, nil +} + +func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address) + if err != nil { + return nil, fmt.Errorf("failed to create new network resource: %w", err) + } + + return resource, m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) +} + +func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + return resource, nil +} + +func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resourceType, err := types.GetResourceType(resource.Address) + if err != nil { + return nil, fmt.Errorf("failed to get resource type: %w", err) + } + + resource.Type = resourceType + + return resource, m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) +} + +func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) +} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go new file mode 100644 index 000000000..dd2c10fa5 --- /dev/null +++ b/management/server/networks/resources/types/resource.go @@ -0,0 +1,105 @@ +package types + +import ( + "errors" + "fmt" + "net" + "regexp" + "strings" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type NetworkResourceType string + +const ( + host NetworkResourceType = "host" + subnet NetworkResourceType = "subnet" + domain NetworkResourceType = "domain" +) + +func (p NetworkResourceType) String() string { + return string(p) +} + +type NetworkResource struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string + Type NetworkResourceType + Address string +} + +func NewNetworkResource(accountID, networkID, name, description, address string) (*NetworkResource, error) { + resourceType, err := GetResourceType(address) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + + return &NetworkResource{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Name: name, + Description: description, + Type: resourceType, + Address: address, + }, nil +} + +func (n *NetworkResource) ToAPIResponse() *api.NetworkResource { + return &api.NetworkResource{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Type: api.NetworkResourceType(n.Type.String()), + Address: n.Address, + } +} + +func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { + n.Name = req.Name + + if req.Description != nil { + n.Description = *req.Description + } + n.Address = req.Address +} + +func (n *NetworkResource) Copy() *NetworkResource { + return &NetworkResource{ + ID: n.ID, + AccountID: n.AccountID, + NetworkID: n.NetworkID, + Name: n.Name, + Description: n.Description, + Type: n.Type, + Address: n.Address, + } +} + +// GetResourceType returns the type of the resource based on the address +func GetResourceType(address string) (NetworkResourceType, error) { + if ip, cidr, err := net.ParseCIDR(address); err == nil { + ones, _ := cidr.Mask.Size() + if strings.HasSuffix(address, "/32") || (ip != nil && ones == 32) { + return host, nil + } + return subnet, nil + } + + if net.ParseIP(address) != nil { + return host, nil + } + + domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) + if domainRegex.MatchString(address) { + return domain, nil + } + + return "", errors.New("not a host, subnet, or domain") +} diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go new file mode 100644 index 000000000..6b12ca0fc --- /dev/null +++ b/management/server/networks/resources/types/resource_test.go @@ -0,0 +1,41 @@ +package types + +import ( + "testing" +) + +func TestGetResourceType(t *testing.T) { + tests := []struct { + input string + expectedType NetworkResourceType + expectedErr bool + }{ + // Valid host IPs + {"1.1.1.1", host, false}, + {"1.1.1.1/32", host, false}, + // Valid subnets + {"192.168.1.0/24", subnet, false}, + {"10.0.0.0/16", subnet, false}, + // Valid domains + {"example.com", domain, false}, + {"*.example.com", domain, false}, + {"sub.example.com", domain, false}, + // Invalid inputs + {"invalid", "", true}, + {"1.1.1.1/abc", "", true}, + {"1234", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, err := GetResourceType(tt.input) + if result != tt.expectedType { + t.Errorf("Expected type %v, got %v", tt.expectedType, result) + } + + if tt.expectedErr && err == nil { + t.Errorf("Expected error, got nil") + } + }) + } +} diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go new file mode 100644 index 000000000..0ced5ac9b --- /dev/null +++ b/management/server/networks/routers/manager.go @@ -0,0 +1,128 @@ +package routers + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) + GetAllRouterIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) + CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) + UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllRouterIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network routers: %w", err) + } + + routersMap := make(map[string][]string) + for _, router := range routers { + routersMap[router.NetworkID] = append(routersMap[router.NetworkID], router.ID) + } + + return routersMap, nil +} + +func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + router.ID = xid.New().String() + + return router, m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) +} + +func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, errors.New("router not part of network") + } + + return router, nil +} + +func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return router, m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) +} + +func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) +} diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go new file mode 100644 index 000000000..b1491d2d1 --- /dev/null +++ b/management/server/networks/routers/types/router.go @@ -0,0 +1,70 @@ +package types + +import ( + "errors" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type NetworkRouter struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Peer string + PeerGroups []string `gorm:"serializer:json"` + Masquerade bool + Metric int +} + +func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int) (*NetworkRouter, error) { + if peer != "" && len(peerGroups) > 0 { + return nil, errors.New("peer and peerGroups cannot be set at the same time") + } + + return &NetworkRouter{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Peer: peer, + PeerGroups: peerGroups, + Masquerade: masquerade, + Metric: metric, + }, nil +} + +func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { + return &api.NetworkRouter{ + Id: n.ID, + Peer: &n.Peer, + PeerGroups: &n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} + +func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) { + if req.Peer != nil { + n.Peer = *req.Peer + } + + if req.PeerGroups != nil { + n.PeerGroups = *req.PeerGroups + } + + n.Masquerade = req.Masquerade + n.Metric = req.Metric +} + +func (n *NetworkRouter) Copy() *NetworkRouter { + return &NetworkRouter{ + ID: n.ID, + NetworkID: n.NetworkID, + AccountID: n.AccountID, + Peer: n.Peer, + PeerGroups: n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go new file mode 100644 index 000000000..3335f7c89 --- /dev/null +++ b/management/server/networks/routers/types/router_test.go @@ -0,0 +1,100 @@ +package types + +import "testing" + +func TestNewNetworkRouter(t *testing.T) { + tests := []struct { + name string + accountID string + networkID string + peer string + peerGroups []string + masquerade bool + metric int + expectedError bool + }{ + // Valid cases + { + name: "Valid with peer only", + networkID: "network-1", + accountID: "account-1", + peer: "peer-1", + peerGroups: nil, + masquerade: true, + metric: 100, + expectedError: false, + }, + { + name: "Valid with peerGroups only", + networkID: "network-2", + accountID: "account-2", + peer: "", + peerGroups: []string{"group-1", "group-2"}, + masquerade: false, + metric: 200, + expectedError: false, + }, + { + name: "Valid with no peer or peerGroups", + networkID: "network-3", + accountID: "account-3", + peer: "", + peerGroups: nil, + masquerade: true, + metric: 300, + expectedError: false, + }, + + // Invalid cases + { + name: "Invalid with both peer and peerGroups", + networkID: "network-4", + accountID: "account-4", + peer: "peer-2", + peerGroups: []string{"group-3"}, + masquerade: false, + metric: 400, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric) + + if tt.expectedError && err == nil { + t.Fatalf("Expected an error, got nil") + } + + if tt.expectedError == false { + if router == nil { + t.Fatalf("Expected a NetworkRouter object, got nil") + } + + if router.AccountID != tt.accountID { + t.Errorf("Expected AccountID %s, got %s", tt.accountID, router.AccountID) + } + + if router.NetworkID != tt.networkID { + t.Errorf("Expected NetworkID %s, got %s", tt.networkID, router.NetworkID) + } + + if router.Peer != tt.peer { + t.Errorf("Expected Peer %s, got %s", tt.peer, router.Peer) + } + + if len(router.PeerGroups) != len(tt.peerGroups) { + t.Errorf("Expected PeerGroups %v, got %v", tt.peerGroups, router.PeerGroups) + } + + if router.Masquerade != tt.masquerade { + t.Errorf("Expected Masquerade %v, got %v", tt.masquerade, router.Masquerade) + } + + if router.Metric != tt.metric { + t.Errorf("Expected Metric %d, got %d", tt.metric, router.Metric) + } + } + }) + } +} diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go new file mode 100644 index 000000000..b884690d5 --- /dev/null +++ b/management/server/networks/types/network.go @@ -0,0 +1,50 @@ +package types + +import ( + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Network struct { + ID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string +} + +func NewNetwork(accountId, name, description string) *Network { + return &Network{ + ID: xid.New().String(), + AccountID: accountId, + Name: name, + Description: description, + } +} + +func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string) *api.Network { + return &api.Network{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Routers: routerIDs, + Resources: resourceIDs, + } +} + +func (n *Network) FromAPIRequest(req *api.NetworkRequest) { + n.Name = req.Name + if req.Description != nil { + n.Description = *req.Description + } +} + +// Copy returns a copy of a posture checks. +func (n *Network) Copy() *Network { + return &Network{ + ID: n.ID, + AccountID: n.AccountID, + Name: n.Name, + Description: n.Description, + } +} diff --git a/management/server/peer.go b/management/server/peer.go index ba211be96..616bd11ad 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -107,7 +109,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return fmt.Errorf("failed to find peer by pub key: %w", err) @@ -139,7 +141,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -213,9 +215,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peerLabelUpdated { peer.Name = update.Name - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() - newLabel, err := getPeerHostLabel(peer.Name, existingLabels) + newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels) if err != nil { return nil, err } @@ -278,7 +280,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error { // the first loop is needed to ensure all peers present under the account before modifying, otherwise // we might have some inconsistencies @@ -316,7 +318,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -358,7 +360,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -383,7 +385,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -399,7 +401,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -433,7 +435,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } @@ -446,12 +448,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var newPeer *nbpeer.Peer var groupsToAdd []string - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var setupKeyID string var setupKeyName string var ephemeral bool if addedByUser { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) if err != nil { return fmt.Errorf("failed to get user groups: %w", err) } @@ -460,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -533,7 +535,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("failed to get account settings: %w", err) } @@ -558,7 +560,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -627,18 +629,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return newPeer, networkMap, postureChecks, nil } -func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { - takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) { + takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed to get taken IPs: %w", err) } - network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + network, err := s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed getting network: %w", err) } - nextIp, err := AllocatePeerIP(network.Net, takenIps) + nextIp, err := types.AllocatePeerIP(network.Net, takenIps) if err != nil { return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) } @@ -647,7 +649,7 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() @@ -695,7 +697,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if peerNotValid { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, []*posture.Checks{}, nil @@ -710,7 +712,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -730,7 +732,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -755,12 +757,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -785,7 +787,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -849,7 +851,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -860,7 +862,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -872,11 +874,11 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { var postureChecks []*posture.Checks if isRequiresApproval { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, nil, nil @@ -896,7 +898,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -918,7 +920,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error { if peer.AddedWithSSOLogin() { if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") @@ -939,7 +941,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error return nil } -func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { @@ -991,7 +993,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, } for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -1069,7 +1071,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b15315f98..72a39441e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -24,10 +24,11 @@ import ( "github.com/netbirdio/netbird/management/proto" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" nbroute "github.com/netbirdio/netbird/route" ) @@ -37,13 +38,13 @@ func TestPeer_LoginExpired(t *testing.T) { expirationEnabled bool lastLogin time.Time expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Login Expiration Disabled. Peer Login Should Not Expire", expirationEnabled: false, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -53,7 +54,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Expire", expirationEnabled: true, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -63,7 +64,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Not Expire", expirationEnabled: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -92,14 +93,14 @@ func TestPeer_SessionExpired(t *testing.T) { lastLogin time.Time connected bool expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", expirationEnabled: false, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Hour, }, @@ -110,7 +111,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -121,7 +122,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -161,7 +162,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -233,9 +234,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) } - var setupKey *SetupKey + var setupKey *types.SetupKey for _, key := range account.SetupKeys { - if key.Type == SetupKeyReusable { + if key.Type == types.SetupKeyReusable { setupKey = key } } @@ -281,8 +282,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 nbgroup.Group - group2 nbgroup.Group + group1 types.Group + group2 types.Group ) group1.ID = xid.New().String() @@ -303,16 +304,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy := &Policy{ + policy := &types.Policy{ Name: "test", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{group1.ID}, Destinations: []string{group2.ID}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -410,7 +411,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -469,9 +470,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } account.Settings.RegularUsersViewBlocked = false @@ -482,7 +483,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -567,77 +568,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { func TestDefaultAccountManager_GetPeers(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool isServiceUser bool expectedPeerCount int }{ { name: "Regular user, no limited view settings, not a service user", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 1, }, { name: "Service user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 0, }, { name: "Service user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, no limited view settings, not a service user", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin service user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin Service user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, @@ -656,12 +657,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, Role: testCase.role, IsServiceUser: testCase.isServiceUser, } - account.Policies = []*Policy{} + account.Policies = []*types.Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings err = manager.Store.SaveAccount(context.Background(), account) @@ -726,9 +727,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou regularUser := "regular_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ + account.Users[regularUser] = &types.User{ Id: regularUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } // Create peers @@ -746,10 +747,10 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) + account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) - group := &nbgroup.Group{ + group := &types.Group{ ID: groupID, Name: fmt.Sprintf("Group %d", i), } @@ -760,11 +761,11 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou account.Groups[groupID] = group // Create a policy for this group - policy := &Policy{ + policy := &types.Policy{ ID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Policy for Group %d", i), Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), @@ -772,8 +773,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Sources: []string{groupID}, Destinations: []string{groupID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -939,8 +940,8 @@ func TestToSyncResponse(t *testing.T) { Payload: "turn-user", Signature: "turn-pass", } - networkMap := &NetworkMap{ - Network: &Network{Net: *ipnet, Serial: 1000}, + networkMap := &types.NetworkMap{ + Network: &types.Network{Net: *ipnet, Serial: 1000}, Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ @@ -987,8 +988,8 @@ func TestToSyncResponse(t *testing.T) { }, CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + FirewallRules: []*types.FirewallRule{ + {PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"}, }, } dnsName := "example.com" @@ -1088,7 +1089,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1099,13 +1100,13 @@ func Test_RegisterPeerByUser(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1128,12 +1129,12 @@ func Test_RegisterPeerByUser(t *testing.T) { addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.UserID, existingUserID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) @@ -1152,7 +1153,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1163,13 +1164,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1192,11 +1193,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) @@ -1219,7 +1220,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1230,13 +1231,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1258,10 +1259,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) require.Error(t, err) - _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.Error(t, err) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.NotContains(t, account.Peers, newPeer.ID) assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) @@ -1284,7 +1285,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1304,26 +1305,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) // create a user with auto groups - _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ { Id: "regularUser1", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupA"}, }, { Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupB"}, }, { Id: "regularUser3", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupC"}, }, }, true) @@ -1464,15 +1465,15 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go new file mode 100644 index 000000000..5d1ba2320 --- /dev/null +++ b/management/server/permissions/manager.go @@ -0,0 +1,87 @@ +package permissions + +import ( + "context" + "errors" + "fmt" + + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" +) + +type Module string + +const ( + Networks Module = "networks" + Peers Module = "peers" +) + +type Operation string + +const ( + Read Operation = "read" + Write Operation = "write" +) + +type Manager interface { + ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) +} + +type managerImpl struct { + userManager users.Manager + settingsManager settings.Manager +} + +func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager { + return &managerImpl{ + userManager: userManager, + settingsManager: settingsManager, + } +} + +func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + user, err := m.userManager.GetUser(ctx, userID) + if err != nil { + return false, err + } + + if user == nil { + return false, errors.New("user not found") + } + + if user.AccountID != accountID { + return false, errors.New("user does not belong to account") + } + + switch user.Role { + case types.UserRoleAdmin, types.UserRoleOwner: + return true, nil + case types.UserRoleUser: + return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation) + case types.UserRoleBillingAdmin: + return false, nil + default: + return false, errors.New("invalid role") + } +} + +func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + settings, err := m.settingsManager.GetSettings(ctx, accountID, userID) + if err != nil { + return false, fmt.Errorf("failed to get settings: %w", err) + } + if settings.RegularUsersViewBlocked { + return false, nil + } + + if operation == Write { + return false, nil + } + + if module == Peers { + return true, nil + } + + return false, nil +} diff --git a/management/server/policy.go b/management/server/policy.go index 2d3abc3f1..8ae2f96d0 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,344 +3,21 @@ package server import ( "context" _ "embed" - "strconv" - "strings" + + "github.com/rs/xid" "github.com/netbirdio/netbird/management/proto" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PolicyUpdateOperationType operation type -type PolicyUpdateOperationType int - -// PolicyTrafficActionType action type for the firewall -type PolicyTrafficActionType string - -// PolicyRuleProtocolType type of traffic -type PolicyRuleProtocolType string - -// PolicyRuleDirection direction of traffic -type PolicyRuleDirection string - -const ( - // PolicyTrafficActionAccept indicates that the traffic is accepted - PolicyTrafficActionAccept = PolicyTrafficActionType("accept") - // PolicyTrafficActionDrop indicates that the traffic is dropped - PolicyTrafficActionDrop = PolicyTrafficActionType("drop") -) - -const ( - // PolicyRuleProtocolALL type of traffic - PolicyRuleProtocolALL = PolicyRuleProtocolType("all") - // PolicyRuleProtocolTCP type of traffic - PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") - // PolicyRuleProtocolUDP type of traffic - PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") - // PolicyRuleProtocolICMP type of traffic - PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") -) - -const ( - // PolicyRuleFlowDirect allows traffic from source to destination - PolicyRuleFlowDirect = PolicyRuleDirection("direct") - // PolicyRuleFlowBidirect allows traffic to both directions - PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") -) - -const ( - // DefaultRuleName is a name for the Default rule that is created for every account - DefaultRuleName = "Default" - // DefaultRuleDescription is a description for the Default rule that is created for every account - DefaultRuleDescription = "This is a default rule that allows connections between all the resources" - // DefaultPolicyName is a name for the Default policy that is created for every account - DefaultPolicyName = "Default" - // DefaultPolicyDescription is a description for the Default policy that is created for every account - DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" -) - -const ( - firewallRuleDirectionIN = 0 - firewallRuleDirectionOUT = 1 -) - -// PolicyUpdateOperation operation object with type and values to be applied -type PolicyUpdateOperation struct { - Type PolicyUpdateOperationType - Values []string -} - -// RulePortRange represents a range of ports for a firewall rule. -type RulePortRange struct { - Start uint16 - End uint16 -} - -// PolicyRule is the metadata of the policy -type PolicyRule struct { - // ID of the policy rule - ID string `gorm:"primaryKey"` - - // PolicyID is a reference to Policy that this object belongs - PolicyID string `json:"-" gorm:"index"` - - // Name of the rule visible in the UI - Name string - - // Description of the rule visible in the UI - Description string - - // Enabled status of rule in the system - Enabled bool - - // Action policy accept or drops packets - Action PolicyTrafficActionType - - // Destinations policy destination groups - Destinations []string `gorm:"serializer:json"` - - // Sources policy source groups - Sources []string `gorm:"serializer:json"` - - // Bidirectional define if the rule is applicable in both directions, sources, and destinations - Bidirectional bool - - // Protocol type of the traffic - Protocol PolicyRuleProtocolType - - // Ports or it ranges list - Ports []string `gorm:"serializer:json"` - - // PortRanges a list of port ranges. - PortRanges []RulePortRange `gorm:"serializer:json"` -} - -// Copy returns a copy of a policy rule -func (pm *PolicyRule) Copy() *PolicyRule { - rule := &PolicyRule{ - ID: pm.ID, - PolicyID: pm.PolicyID, - Name: pm.Name, - Description: pm.Description, - Enabled: pm.Enabled, - Action: pm.Action, - Destinations: make([]string, len(pm.Destinations)), - Sources: make([]string, len(pm.Sources)), - Bidirectional: pm.Bidirectional, - Protocol: pm.Protocol, - Ports: make([]string, len(pm.Ports)), - PortRanges: make([]RulePortRange, len(pm.PortRanges)), - } - copy(rule.Destinations, pm.Destinations) - copy(rule.Sources, pm.Sources) - copy(rule.Ports, pm.Ports) - copy(rule.PortRanges, pm.PortRanges) - return rule -} - -// Policy of the Rego query -type Policy struct { - // ID of the policy' - ID string `gorm:"primaryKey"` - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name of the Policy - Name string - - // Description of the policy visible in the UI - Description string - - // Enabled status of the policy - Enabled bool - - // Rules of the policy - Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` - - // SourcePostureChecks are ID references to Posture checks for policy source groups - SourcePostureChecks []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the policy. -func (p *Policy) Copy() *Policy { - c := &Policy{ - ID: p.ID, - AccountID: p.AccountID, - Name: p.Name, - Description: p.Description, - Enabled: p.Enabled, - Rules: make([]*PolicyRule, len(p.Rules)), - SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), - } - for i, r := range p.Rules { - c.Rules[i] = r.Copy() - } - copy(c.SourcePostureChecks, p.SourcePostureChecks) - return c -} - -// EventMeta returns activity event meta related to this policy -func (p *Policy) EventMeta() map[string]any { - return map[string]any{"name": p.Name} -} - -// UpgradeAndFix different version of policies to latest version -func (p *Policy) UpgradeAndFix() { - for _, r := range p.Rules { - // start migrate from version v0.20.3 - if r.Protocol == "" { - r.Protocol = PolicyRuleProtocolALL - } - if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { - r.Bidirectional = true - } - // -- v0.20.4 - } -} - -// ruleGroups returns a list of all groups referenced in the policy's rules, -// including sources and destinations. -func (p *Policy) ruleGroups() []string { - groups := make([]string, 0) - for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) - groups = append(groups, rule.Destinations...) - } - - return groups -} - -// FirewallRule is a rule of the firewall. -type FirewallRule struct { - // PeerIP of the peer - PeerIP string - - // Direction of the traffic - Direction int - - // Action of the traffic - Action string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port string -} - -// getPeerConnectionResources for a given peer -// -// This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) - for _, policy := range a.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) - - if rule.Bidirectional { - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionIN) - } - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionOUT) - } - } - - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionOUT) - } - - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionIN) - } - } - } - - return getAccumulatedResources() -} - -// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls -// -// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. -// It safe to call the generator function multiple times for same peer and different rules no duplicates will be -// generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - rules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &nbgroup.Group{} - } - - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) - for _, peer := range groupPeers { - if peer == nil { - continue - } - - if _, ok := peersExists[peer.ID]; !ok { - peers = append(peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = "0.0.0.0" - } - - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 { - rules = append(rules, &fr) - continue - } - - for _, port := range rule.Ports { - pr := fr // clone rule and add set new port - pr.Port = port - rules = append(rules, &pr) - } - } - }, func() ([]*nbpeer.Peer, []*FirewallRule) { - return peers, rules - } -} - // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -353,15 +30,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, status.NewAdminPermissionError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -378,7 +55,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var updateAccountPeers bool var action = activity.PolicyAdded - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { return err } @@ -388,7 +65,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -398,7 +75,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user saveFunc = transaction.SavePolicy } - return saveFunc(ctx, LockingStrengthUpdate, policy) + return saveFunc(ctx, store.LockingStrengthUpdate, policy) }) if err != nil { return nil, err @@ -418,7 +95,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -431,11 +108,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return status.NewAdminPermissionError() } - var policy *Policy + var policy *types.Policy var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -445,11 +122,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID) }) if err != nil { return err @@ -465,8 +142,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po } // ListPolicies from the store. -func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -479,13 +156,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) } // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return false, err } @@ -494,7 +171,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err } @@ -504,13 +181,13 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account } } - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return err } @@ -519,12 +196,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -548,84 +225,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po return nil } -// getAllPeersFromGroups for given peer ID and list of groups -// -// Returns a list of peers from specified groups that pass specified posture checks -// and a boolean indicating if the supplied peer ID exists within these groups. -// -// Important: Posture checks are applicable only to source group peers, -// for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { - peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { - continue - } - - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) - } - } - return filteredPeers, peerInGroups -} - -// validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { - peer, ok := a.Peers[peerID] - if !ok && peer == nil { - return false - } - - for _, postureChecksID := range sourcePostureChecksID { - postureChecks := a.getPostureChecks(postureChecksID) - if postureChecks == nil { - continue - } - - for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(ctx, *peer) - if err != nil { - log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) - } - if !isValid { - return false - } - } - } - return true -} - -func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { - for _, postureChecks := range a.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks - } - } - return nil -} - // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) @@ -639,7 +238,7 @@ func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureCh } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { +func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { if _, exists := groups[id]; exists { @@ -651,7 +250,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(rules)) for i := range rules { rule := rules[i] diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 62d80f46e..fab738abe 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -10,13 +10,13 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" ) func TestAccount_getPeersByPolicy(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -59,7 +59,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -87,21 +87,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -116,15 +116,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", "GroupAll", @@ -145,14 +145,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -160,45 +160,45 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -206,14 +206,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -221,14 +221,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -236,14 +236,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -251,14 +251,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -266,14 +266,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -289,7 +289,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { } func TestAccount_getPeersByPolicyDirect(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -307,7 +307,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -332,21 +332,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: false, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: false, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -361,15 +361,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", }, @@ -388,20 +388,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -416,20 +416,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -446,13 +446,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -467,13 +467,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -489,7 +489,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -582,7 +582,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -630,17 +630,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, } - account.Policies = append(account.Policies, &Policy{ + account.Policies = append(account.Policies, &types.Policy{ ID: "PolicyPostureChecks", Name: "", Description: "This is the policy with posture checks applied", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Enabled: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, Destinations: []string{ "GroupSwarm", }, @@ -648,7 +648,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { "GroupAll", }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"}, }, }, @@ -664,7 +664,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -674,13 +674,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", @@ -690,7 +690,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -700,7 +700,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -715,19 +715,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -742,14 +742,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) @@ -760,45 +760,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // assert peers from Group All assert.Contains(t, peers, account.Peers["peerC"]) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", @@ -809,8 +809,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }) } -func sortFunc() func(a *FirewallRule, b *FirewallRule) int { - return func(a, b *FirewallRule) int { +func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { + return func(a, b *types.FirewallRule) int { // Concatenate PeerIP and Direction as string for comparison aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction) bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction) @@ -829,7 +829,7 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -858,9 +858,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var policyWithGroupRulesNoPeers *Policy - var policyWithDestinationPeersOnly *Policy - var policyWithSourceAndDestinationPeers *Policy + var policyWithGroupRulesNoPeers *types.Policy + var policyWithDestinationPeersOnly *types.Policy + var policyWithSourceAndDestinationPeers *types.Policy // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { @@ -870,16 +870,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -901,17 +901,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -933,17 +933,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -965,16 +965,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 0467efedb..c9329766b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -12,10 +12,12 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -28,7 +30,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.NewAdminPermissionError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) } // SavePostureChecks saves a posture check. @@ -36,7 +38,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -53,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { return err } @@ -64,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -72,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) }) if err != nil { return nil, err @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -107,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun var postureChecks *posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) if err != nil { return err } @@ -117,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) }) if err != nil { return err @@ -134,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun // ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -147,11 +149,11 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) } // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) if len(account.PostureChecks) == 0 { @@ -172,15 +174,15 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID s } // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.RuleGroups()) if err != nil { return false, err } @@ -195,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a } // validatePostureChecks validates the posture checks. -func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { +func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. if postureChecks.ID != "" { - if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { return err } return nil } // For new posture checks, ensure no duplicates by name. - checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -226,7 +228,7 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err @@ -237,7 +239,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.getPostureChecks(sourcePostureCheckID) + postureCheck := account.GetPostureChecks(sourcePostureCheckID) if postureCheck == nil { return errors.New("failed to add policy posture checks: posture checks not found") } @@ -248,7 +250,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue @@ -270,8 +272,8 @@ func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) } // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. -func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf..bad162f05 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" ) @@ -92,17 +93,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ + admin := &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, } - user := &User{ + user := &types.User{ Id: regularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) @@ -120,7 +121,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -209,15 +210,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -312,15 +313,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -356,15 +357,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -395,15 +396,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -443,18 +444,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { account, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") - groupA := &group.Group{ + groupA := &types.Group{ ID: "groupA", AccountID: account.Id, Peers: []string{"peer1"}, } - groupB := &group.Group{ + groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ @@ -477,9 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") - policy := &Policy{ + policy := &types.Policy{ AccountID: account.Id, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, @@ -534,7 +535,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/resource.go b/management/server/resource.go new file mode 100644 index 000000000..77a5612b3 --- /dev/null +++ b/management/server/resource.go @@ -0,0 +1,21 @@ +package server + +type ResourceType string + +const ( + // nolint + hostType ResourceType = "Host" + //nolint + subnetType ResourceType = "Subnet" + // nolint + domainType ResourceType = "Domain" +) + +func (p ResourceType) String() string { + return string(p) +} + +type Resource struct { + Type ResourceType + ID string +} diff --git a/management/server/route.go b/management/server/route.go index 23bea87e3..49d76bc43 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,15 +4,12 @@ import ( "context" "fmt" "net/netip" - "slices" - "strconv" - "strings" "unicode/utf8" "github.com/rs/xid" - log "github.com/sirupsen/logrus" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -21,33 +18,9 @@ import ( "github.com/netbirdio/netbird/route" ) -// RouteFirewallRule a firewall rule applicable for a routed network. -type RouteFirewallRule struct { - // SourceRanges IP ranges of the routing peers. - SourceRanges []string - - // Action of the traffic when the rule is applicable - Action string - - // Destination a network prefix for the routed traffic - Destination string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port uint16 - - // PortRange represents the range of ports for a firewall rule - PortRange RulePortRange - - // isDynamic indicates whether the rule is for DNS routing - IsDynamic bool -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) @@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { @@ -404,244 +377,7 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - -func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) - fwRules = append(fwRules, rules...) - } - } - return fwRules -} - -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - for _, id := range rule.Sources { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := a.Peers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - IsDynamic: route.IsDynamic(), - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - -// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups -// and returns a list of policies that have rules with destinations matching the specified groups. -func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { - routePolicies := make([]*Policy, 0) - for _, groupID := range accessControlGroups { - group, ok := account.Groups[groupID] - if !ok { - continue - } - - for _, policy := range account.Policies { - for _, rule := range policy.Rules { - exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { - return groupID == group.ID - }) - if exist { - routePolicies = append(routePolicies, policy) - continue - } - } - } - } - - return routePolicies -} - -// generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { - rulesExists := make(map[string]struct{}) - rules := make([]*RouteFirewallRule, 0) - - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) - } - - baseRule := RouteFirewallRule{ - SourceRanges: sourceRanges, - Action: string(rule.Action), - Destination: route.Network.String(), - Protocol: string(rule.Protocol), - IsDynamic: route.IsDynamic(), - } - - // generate rule for port range - if len(rule.Ports) == 0 { - rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) - } else { - rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) - - } - - // TODO: generate IPv6 rules for dynamic routes - - return rules -} - -// generateRuleIDBase generates the base rule ID for checking duplicates. -func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { - return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action -} - -// generateRulesForPeer generates rules for a given peer based on ports and port ranges. -func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - - ruleIDBase := generateRuleIDBase(rule, baseRule) - if len(rule.Ports) == 0 { - if len(rule.PortRanges) == 0 { - if _, ok := rulesExists[ruleIDBase]; !ok { - rulesExists[ruleIDBase] = struct{}{} - rules = append(rules, &baseRule) - } - } else { - for _, portRange := range rule.PortRanges { - ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) - if _, ok := rulesExists[ruleID]; !ok { - rulesExists[ruleID] = struct{}{} - pr := baseRule - pr.PortRange = portRange - rules = append(rules, &pr) - } - } - } - return rules - } - - return rules -} - -// generateRulesWithPorts generates rules when specific ports are provided. -func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - ruleIDBase := generateRuleIDBase(rule, baseRule) - - for _, port := range rule.Ports { - ruleID := ruleIDBase + port - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - pr := baseRule - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) - continue - } - - pr.Port = uint16(p) - rules = append(rules, &pr) - } - - return rules -} - -func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { result := make([]*proto.RouteFirewallRule, len(rules)) for i := range rules { rule := rules[i] @@ -660,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { - if direction == firewallRuleDirectionOUT { + if direction == types.FirewallRuleDirectionOUT { return proto.RuleDirection_OUT } return proto.RuleDirection_IN @@ -668,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection { // getProtoAction converts the action to proto.RuleAction. func getProtoAction(action string) proto.RuleAction { - if action == string(PolicyTrafficActionDrop) { + if action == string(types.PolicyTrafficActionDrop) { return proto.RuleAction_DROP } return proto.RuleAction_ACCEPT @@ -676,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction { // getProtoProtocol converts the protocol to proto.RuleProtocol. func getProtoProtocol(protocol string) proto.RuleProtocol { - switch PolicyRuleProtocolType(protocol) { - case PolicyRuleProtocolALL: + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: return proto.RuleProtocol_ALL - case PolicyRuleProtocolTCP: + case types.PolicyRuleProtocolTCP: return proto.RuleProtocol_TCP - case PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolUDP: return proto.RuleProtocol_UDP - case PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolICMP: return proto.RuleProtocol_ICMP default: return proto.RuleProtocol_UNKNOWN @@ -691,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol { } // getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { var portInfo proto.PortInfo if rule.Port != 0 { portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} @@ -708,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/route_test.go b/management/server/route_test.go index 8bf9a3aeb..5e2e24611 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -15,9 +15,10 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -1092,9 +1093,9 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *nbgroup.Group + var groupHA1, groupHA2 *types.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -1202,7 +1203,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &nbgroup.Group{ + newGroup := &types.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1255,10 +1256,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createRouterStore(t *testing.T) (Store, error) { +func createRouterStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1267,7 +1268,7 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() accountID := "testingAcc" @@ -1279,8 +1280,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + ips := account.GetTakenIPs() + peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1306,8 +1307,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer1.ID] = peer1 - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1333,8 +1334,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer2.ID] = peer2 - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1360,8 +1361,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer3.ID] = peer3 - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1387,8 +1388,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer4.ID] = peer4 - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1439,7 +1440,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*nbgroup.Group{ + newGroup := []*types.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1491,7 +1492,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerKIp = "100.65.29.66" ) - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -1555,7 +1556,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "routingPeer1": { ID: "routingPeer1", Name: "RoutingPeer1", @@ -1685,19 +1686,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { AccessControlGroups: []string{"route4"}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleRoute1", Name: "Route1", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute1", Name: "ruleRoute1", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80", "320"}, Sources: []string{ "dev", @@ -1712,15 +1713,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute2", Name: "Route2", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute2", Name: "ruleRoute2", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, - PortRanges: []RulePortRange{ + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ { Start: 80, End: 350, @@ -1742,14 +1743,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute4", Name: "RuleRoute4", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute4", Name: "RuleRoute4", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80"}, Sources: []string{ "restrictQA", @@ -1764,14 +1765,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute5", Name: "RuleRoute5", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute5", Name: "RuleRoute5", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "unrestrictedQA", }, @@ -1791,28 +1792,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] - policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) assert.Len(t, policies, 1) route2 := account.Routes["route2"] - policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) assert.Len(t, policies, 1) route3 := account.Routes["route3"] - policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) assert.Len(t, routesFirewallRules, 4) - expectedRoutesFirewallRules := []*RouteFirewallRule{ + expectedRoutesFirewallRules := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1821,9 +1822,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1831,10 +1832,10 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - additionalFirewallRule := []*RouteFirewallRule{ + additionalFirewallRule := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerJIp), + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1843,7 +1844,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerKIp), + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1854,21 +1855,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) assert.Len(t, routesFirewallRules, 3) - expectedRoutesFirewallRules = []*RouteFirewallRule{ + expectedRoutesFirewallRules = []*types.RouteFirewallRule{ { SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, Action: "accept", Destination: existingNetwork.String(), Protocol: "tcp", - PortRange: RulePortRange{Start: 80, End: 350}, + PortRange: types.RulePortRange{Start: 80, End: 350}, }, { SourceRanges: []string{"0.0.0.0/0"}, @@ -1888,14 +1889,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) assert.Len(t, routesFirewallRules, 0) }) } // orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { +func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { for _, rule := range ruleList { sort.Strings(rule.SourceRanges) } @@ -1909,7 +1910,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -2105,7 +2106,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2145,7 +2146,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go new file mode 100644 index 000000000..7d564a02e --- /dev/null +++ b/management/server/settings/manager.go @@ -0,0 +1,26 @@ +package settings + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) +} + +type managerImpl struct { + store store.Store +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index ef431d3ad..9a4a1efb8 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,34 +2,16 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" - "hash/fnv" "slices" - "strconv" - "strings" "time" - "unicode/utf8" - "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - // SetupKeyReusable is a multi-use key (can be used for multiple machines) - SetupKeyReusable SetupKeyType = "reusable" - // SetupKeyOneOff is a single use key (can be used only once) - SetupKeyOneOff SetupKeyType = "one-off" - - // DefaultSetupKeyDuration = 1 month - DefaultSetupKeyDuration = 24 * 30 * time.Hour - // DefaultSetupKeyName is a default name of the default setup key - DefaultSetupKeyName = "Default key" - // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key - SetupKeyUnlimitedUsage = 0 + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -67,169 +49,14 @@ type SetupKeyUpdateOperation struct { Values []string } -// SetupKeyType is the type of setup key -type SetupKeyType string - -// SetupKey represents a pre-authorized key used to register machines (peers) -type SetupKey struct { - Id string - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Key string - KeySecret string - Name string - Type SetupKeyType - CreatedAt time.Time - ExpiresAt time.Time - UpdatedAt time.Time `gorm:"autoUpdateTime:false"` - // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) - Revoked bool - // UsedTimes indicates how many times the key was used - UsedTimes int - // LastUsed last time the key was used for peer registration - LastUsed time.Time - // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register - AutoGroups []string `gorm:"serializer:json"` - // UsageLimit indicates the number of times this key can be used to enroll a machine. - // The value of 0 indicates the unlimited usage. - UsageLimit int - // Ephemeral indicate if the peers will be ephemeral or not - Ephemeral bool -} - -// Copy copies SetupKey to a new object -func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, len(key.AutoGroups)) - copy(autoGroups, key.AutoGroups) - if key.UpdatedAt.IsZero() { - key.UpdatedAt = key.CreatedAt - } - return &SetupKey{ - Id: key.Id, - AccountID: key.AccountID, - Key: key.Key, - KeySecret: key.KeySecret, - Name: key.Name, - Type: key.Type, - CreatedAt: key.CreatedAt, - ExpiresAt: key.ExpiresAt, - UpdatedAt: key.UpdatedAt, - Revoked: key.Revoked, - UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, - AutoGroups: autoGroups, - UsageLimit: key.UsageLimit, - Ephemeral: key.Ephemeral, - } -} - -// EventMeta returns activity event meta related to the setup key -func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} -} - -// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. -// E.g., "831F6*******************************" -func hiddenKey(key string, length int) string { - prefix := key[0:5] - if length > utf8.RuneCountInString(key) { - length = utf8.RuneCountInString(key) - len(prefix) - } - return prefix + strings.Repeat("*", length) -} - -// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now -func (key *SetupKey) IncrementUsage() *SetupKey { - c := key.Copy() - c.UsedTimes++ - c.LastUsed = time.Now().UTC() - return c -} - -// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to -func (key *SetupKey) IsValid() bool { - return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() -} - -// IsRevoked if key was revoked -func (key *SetupKey) IsRevoked() bool { - return key.Revoked -} - -// IsExpired if key was expired -func (key *SetupKey) IsExpired() bool { - if key.ExpiresAt.IsZero() { - return false - } - return time.Now().After(key.ExpiresAt) -} - -// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. -func (key *SetupKey) IsOverUsed() bool { - limit := key.UsageLimit - if key.Type == SetupKeyOneOff { - limit = 1 - } - return limit > 0 && key.UsedTimes >= limit -} - -// GenerateSetupKey generates a new setup key -func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) (*SetupKey, string) { - key := strings.ToUpper(uuid.New().String()) - limit := usageLimit - if t == SetupKeyOneOff { - limit = 1 - } - - expiresAt := time.Time{} - if validFor != 0 { - expiresAt = time.Now().UTC().Add(validFor) - } - - hashedKey := sha256.Sum256([]byte(key)) - encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - - return &SetupKey{ - Id: strconv.Itoa(int(Hash(key))), - Key: encodedHashedKey, - KeySecret: hiddenKey(key, 4), - Name: name, - Type: t, - CreatedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - UpdatedAt: time.Now().UTC(), - Revoked: false, - UsedTimes: 0, - AutoGroups: autoGroups, - UsageLimit: limit, - Ephemeral: ephemeral, - }, key -} - -// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() (*SetupKey, string) { - return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, - SetupKeyUnlimitedUsage, false) -} - -func Hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} - // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. -func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -242,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - var setupKey *SetupKey + var setupKey *types.SetupKey var plainKey string var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return err } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) }) if err != nil { return nil, err @@ -278,7 +105,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. -func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } @@ -286,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -299,16 +126,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewAdminPermissionError() } - var oldKey *SetupKey - var newKey *SetupKey + var oldKey *types.SetupKey + var newKey *types.SetupKey var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return err } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err } @@ -323,13 +150,13 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) }) if err != nil { return nil, err @@ -347,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -361,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -379,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } @@ -394,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -407,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return status.NewAdminPermissionError() } - var deletedSetupKey *SetupKey + var deletedSetupKey *types.SetupKey - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return err } - return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) }) if err != nil { return err @@ -426,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) +func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) if err != nil { return err } @@ -447,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI } // prepareSetupKeyEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 614547c60..f728db5d4 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -30,7 +30,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,15 +49,15 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, - SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{}, + types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } autoGroups := []string{"group_1", "group_2"} revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -85,7 +85,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -105,7 +105,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -114,7 +114,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -167,8 +167,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, - tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn, + tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { if err == nil { @@ -182,7 +182,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, - tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), + tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated @@ -210,7 +210,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } @@ -258,10 +258,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key, plainKey := GenerateDefaultSetupKey() + key, plainKey := types.GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -275,48 +275,48 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) } } -func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, +func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() @@ -388,7 +388,7 @@ func isValidBase64SHA256(encodedKey string) bool { func TestSetupKey_Copy(t *testing.T) { - key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, @@ -399,22 +399,22 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) assert.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"group"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -426,7 +426,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var setupKey *SetupKey + var setupKey *types.SetupKey // Creating setup key should not update account peers and not send peer update t.Run("creating setup key", func(t *testing.T) { @@ -436,7 +436,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { close(done) }() - setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) assert.NoError(t, err) select { @@ -477,7 +477,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t t.Fatal(err) } - key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) assert.NoError(t, err) // revoke the key diff --git a/management/server/status/error.go b/management/server/status/error.go index 59f436f5b..d65931b5a 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -154,3 +154,27 @@ func NewPolicyNotFoundError(policyID string) error { func NewNameServerGroupNotFoundError(nsGroupID string) error { return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) } + +// NewNetworkNotFoundError creates a new Error with NotFound type for a missing network. +func NewNetworkNotFoundError(networkID string) error { + return Errorf(NotFound, "network: %s not found", networkID) +} + +// NewNetworkRouterNotFoundError creates a new Error with NotFound type for a missing network router. +func NewNetworkRouterNotFoundError(routerID string) error { + return Errorf(NotFound, "network router: %s not found", routerID) +} + +// NewNetworkResourceNotFoundError creates a new Error with NotFound type for a missing network resource. +func NewNetworkResourceNotFoundError(resourceID string) error { + return Errorf(NotFound, "network resource: %s not found", resourceID) +} + +// NewPermissionDeniedError creates a new Error with PermissionDenied type for a permission denied error. +func NewPermissionDeniedError() error { + return Errorf(PermissionDenied, "permission denied") +} + +func NewPermissionValidationError(err error) error { + return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err) +} diff --git a/management/server/file_store.go b/management/server/store/file_store.go similarity index 89% rename from management/server/file_store.go rename to management/server/store/file_store.go index f375fb990..9127c2705 100644 --- a/management/server/file_store.go +++ b/management/server/store/file_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -11,9 +11,9 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -22,7 +22,7 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { - Accounts map[string]*Account + Accounts map[string]*types.Account SetupKeyID2AccountID map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"` PeerID2AccountID map[string]string `json:"-"` @@ -55,7 +55,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ - Accounts: make(map[string]*Account), + Accounts: make(map[string]*types.Account), mux: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), @@ -92,12 +92,12 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for accountID, account := range store.Accounts { if account.Settings == nil { - account.Settings = &Settings{ + account.Settings = &types.Settings{ PeerLoginExpirationEnabled: false, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, } } @@ -112,7 +112,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID if user.Issued == "" { - user.Issued = UserIssuedAPI + user.Issued = types.UserIssuedAPI account.Users[user.Id] = user } @@ -122,7 +122,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { } } - if account.Domain != "" && account.DomainCategory == PrivateCategory && + if account.Domain != "" && account.DomainCategory == types.PrivateCategory && account.IsDomainPrimaryAccount { store.PrivateDomain2AccountID[account.Domain] = accountID } @@ -134,20 +134,20 @@ func restore(ctx context.Context, file string) (*FileStore, error) { policy.UpgradeAndFix() } if account.Policies == nil { - account.Policies = make([]*Policy, 0) + account.Policies = make([]*types.Policy, 0) } // for data migration. Can be removed once most base will be with labels - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(ctx, account, existingLabels) + types.AddPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI } } @@ -236,7 +236,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -257,6 +257,6 @@ func (s *FileStore) Close(ctx context.Context) error { } // GetStoreEngine returns FileStoreEngine -func (s *FileStore) GetStoreEngine() StoreEngine { +func (s *FileStore) GetStoreEngine() Engine { return FileStoreEngine } diff --git a/management/server/sql_store.go b/management/server/store/sql_store.go similarity index 79% rename from management/server/sql_store.go rename to management/server/store/sql_store.go index 1fd8ae2aa..771a32aae 100644 --- a/management/server/sql_store.go +++ b/management/server/store/sql_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -24,11 +24,14 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -49,7 +52,7 @@ type SqlStore struct { globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int - storeEngine StoreEngine + storeEngine Engine } type installation struct { @@ -60,7 +63,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -86,9 +89,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, - &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, ) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) @@ -151,7 +155,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u return unlock } -func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() defer func() { elapsed := time.Since(start) @@ -201,7 +205,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { } // generateAccountSQLTypes generates the GORM compatible types for the account -func generateAccountSQLTypes(account *Account) { +func generateAccountSQLTypes(account *types.Account) { for _, key := range account.SetupKeys { account.SetupKeysG = append(account.SetupKeysG, *key) } @@ -238,7 +242,7 @@ func generateAccountSQLTypes(account *Account) { // checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { - var acc Account + var acc types.Account var domain string result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) if result.Error != nil { @@ -252,7 +256,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, } } -func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error { @@ -333,14 +337,14 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { - accountCopy := Account{ + accountCopy := types.Account{ Domain: domain, DomainCategory: category, IsDomainPrimaryAccount: isPrimaryDomain, } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). + result := s.db.Model(&types.Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) @@ -402,8 +406,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. -func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { - usersToSave := make([]User, 0, len(users)) +func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error { + usersToSave := make([]types.User, 0, len(users)) for _, user := range users { user.AccountID = accountID for id, pat := range user.PATs { @@ -423,7 +427,7 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { } // SaveUser saves the given user to the database. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) @@ -432,7 +436,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -454,7 +458,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { return nil, err @@ -466,9 +470,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory, + strings.ToLower(domain), true, types.PrivateCategory, ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -481,8 +485,8 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return accountID, nil } -func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { - var key SetupKey +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { + var key types.SetupKey result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -500,7 +504,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - var token PersonalAccessToken + var token types.PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -513,8 +517,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { - var token PersonalAccessToken +func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) { + var token types.PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -528,13 +532,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - var user User + var user types.User result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) if result.Error != nil { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -542,8 +546,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { - var user User +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + var user types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { @@ -556,8 +560,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { - var users []*User +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { + var users []*types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -570,8 +574,8 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -584,8 +588,8 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return groups, nil } -func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { - var accounts []Account +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { + var accounts []types.Account result := s.db.Find(&accounts) if result.Error != nil { return all @@ -600,7 +604,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { return all } -func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -609,7 +613,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } }() - var account Account + var account types.Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). @@ -624,15 +628,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us for i, policy := range account.Policies { - var rules []*PolicyRule - err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error if err != nil { return nil, status.Errorf(status.NotFound, "rule not found") } account.Policies[i].Rules = rules } - account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { account.SetupKeys[key.Key] = key.Copy() } @@ -644,9 +648,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.PeersG = nil - account.Users = make(map[string]*User, len(account.UsersG)) + account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -654,7 +658,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.UsersG = nil - account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } @@ -675,8 +679,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, return &account, nil } -func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { - var user User +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { + var user types.User result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -692,7 +696,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { @@ -709,7 +713,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { @@ -742,7 +746,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -755,7 +759,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -815,9 +819,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return labels, nil } -func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + var accountNetwork types.AccountNetwork + if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -839,9 +843,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { - var accountSettings AccountSettings - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + var accountSettings types.AccountSettings + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -852,7 +856,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + var user types.User result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -890,7 +894,7 @@ func (s *SqlStore) Close(_ context.Context) error { } // GetStoreEngine returns underlying store engine -func (s *SqlStore) GetStoreEngine() StoreEngine { +func (s *SqlStore) GetStoreEngine() Engine { return s.storeEngine } @@ -982,8 +986,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } -func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - var setupKey SetupKey +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + var setupKey types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, keyQueryCondition, key) if result.Error != nil { @@ -997,7 +1001,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.Model(&SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1016,7 +1020,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1041,7 +1045,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1114,7 +1118,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -1156,9 +1160,9 @@ func (s *SqlStore) GetDB() *gorm.DB { return s.db } -func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - var accountDNSSettings AccountDNSSettings - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { + var accountDNSSettings types.AccountDNSSettings + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1173,7 +1177,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1187,8 +1191,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + var account types.Account + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1201,8 +1205,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { - var group *nbgroup.Group +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + var group *types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1216,8 +1220,8 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { - var group nbgroup.Group +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. @@ -1240,15 +1244,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetGroupsByIDs retrieves groups by their IDs and account ID. -func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } - groupsMap := make(map[string]*nbgroup.Group) + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -1257,7 +1261,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) @@ -1269,7 +1273,7 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete group from store") @@ -1285,7 +1289,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store") @@ -1295,8 +1299,8 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a } // GetAccountPolicies retrieves policies for an account. -func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - var policies []*Policy +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { + var policies []*types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1308,8 +1312,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { - var policy *Policy +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { + var policy *types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). First(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { @@ -1323,7 +1327,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } -func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) @@ -1334,7 +1338,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt } // SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) if err := result.Error; err != nil { @@ -1346,7 +1350,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) return status.Errorf(status.Internal, "failed to delete policy from store") @@ -1442,8 +1446,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // GetAccountSetupKeys retrieves setup keys for an account. -func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - var setupKeys []*SetupKey +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { + var setupKeys []*types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1455,8 +1459,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { - var setupKey *SetupKey +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { + var setupKey *types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { @@ -1471,7 +1475,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) @@ -1483,7 +1487,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") @@ -1583,9 +1587,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } // SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save dns settings to store") @@ -1597,3 +1601,183 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } + +func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { + var networks []*networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get networks from store") + } + + return networks, nil +} + +func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { + var network *networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&network, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkNotFoundError(networkID) + } + + log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network from store") + } + + return network, nil +} + +func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkNotFoundError(networkID) + } + + return nil +} + +func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { + var netRouter *routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netRouter, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkRouterNotFoundError(routerID) + } + log.WithContext(ctx).Errorf("failed to get network router from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network router from store") + } + + return netRouter, nil +} + +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network router to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network router from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkRouterNotFoundError(routerID) + } + + return nil +} + +func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceID) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network resource to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network resource from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkResourceNotFoundError(resourceID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/store/sql_store_test.go similarity index 76% rename from management/server/sql_store_test.go rename to management/server/store/sql_store_test.go index 6064b019f..9bb7addcb 100644 --- a/management/server/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -14,17 +14,24 @@ import ( "time" "github.com/google/uuid" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/posture" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" ) func TestSqlite_NewStore(t *testing.T) { @@ -73,7 +80,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -86,14 +93,14 @@ func runLargeTest(t *testing.T, store Store) { IP: netIP, Name: peerID, DNSLabel: peerID, - UserID: userID, + UserID: "testuser", Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } account.Peers[peerID] = peer group, _ := account.GetGroupAll() group.Peers = append(group.Peers, peerID) - user := &User{ + user := &types.User{ Id: fmt.Sprintf("%s-user-%d", account.Id, n), AccountID: account.Id, } @@ -111,7 +118,7 @@ func runLargeTest(t *testing.T, store Store) { } account.Routes[route.ID] = route - group = &nbgroup.Group{ + group = &types.Group{ ID: fmt.Sprintf("group-id-%d", n), AccountID: account.Id, Name: fmt.Sprintf("group-id-%d", n), @@ -134,7 +141,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -216,7 +223,7 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -230,7 +237,7 @@ func TestSqlite_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -289,14 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -306,6 +313,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user + account.Networks = []*networkTypes.Network{ + { + ID: "network_id", + AccountID: account.Id, + Name: "network name", + Description: "network description", + }, + } + account.NetworkRouters = []*routerTypes.NetworkRouter{ + { + ID: "router_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + PeerGroups: []string{"group_id"}, + Masquerade: true, + Metric: 1, + }, + } + account.NetworkResources = []*resourceTypes.NetworkResource{ + { + ID: "resource_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + Name: "Name", + Description: "Description", + Type: "Domain", + Address: "example.com", + }, + } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -337,21 +373,30 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") } + for _, network := range account.Networks { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network routers") + require.Len(t, routers, 0, "expecting no network routers to be found after DeleteAccount") + + resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources") + require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") + } } func TestSqlite_GetAccount(t *testing.T) { @@ -360,7 +405,7 @@ func TestSqlite_GetAccount(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -383,7 +428,7 @@ func TestSqlite_SavePeer(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -433,7 +478,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -488,7 +533,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -542,7 +587,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -565,7 +610,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -589,7 +634,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -625,7 +670,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to parse CIDR") type network struct { - Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } @@ -640,7 +685,7 @@ func TestMigrate(t *testing.T) { } type account struct { - Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` Peers []peer `gorm:"foreignKey:AccountID;references:id"` } @@ -700,23 +745,10 @@ func TestMigrate(t *testing.T) { } -func newSqliteStore(t *testing.T) *SqlStore { - t.Helper() - - store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - t.Cleanup(func() { - store.Close(context.Background()) - }) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, @@ -755,7 +787,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -769,7 +801,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -828,14 +860,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -876,16 +908,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -899,7 +931,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -940,7 +972,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -960,7 +992,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -978,7 +1010,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -991,7 +1023,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { func TestSqlite_GetTakenIPs(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1036,7 +1068,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1078,7 +1110,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1101,7 +1133,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetSetupKeyBySecret(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1119,14 +1151,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, encodedHashedKey, setupKey.Key) - assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) + assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1162,13 +1194,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: "account-id", Name: "group-name", @@ -1193,7 +1225,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { } func TestSqlite_GetAccoundUsers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1207,7 +1239,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { } func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1253,7 +1285,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { } func TestSqlite_GetGroupByName(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1267,7 +1299,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1283,7 +1315,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1295,7 +1327,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { } func TestSqlStore_GetGroupsByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1338,13 +1370,13 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } func TestSqlStore_SaveGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: accountID, Issued: "api", @@ -1359,13 +1391,13 @@ func TestSqlStore_SaveGroup(t *testing.T) { } func TestSqlStore_SaveGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - groups := []*nbgroup.Group{ + groups := []*types.Group{ { ID: "group-1", AccountID: accountID, @@ -1384,7 +1416,7 @@ func TestSqlStore_SaveGroups(t *testing.T) { } func TestSqlStore_DeleteGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1432,7 +1464,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { } func TestSqlStore_DeleteGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1479,7 +1511,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) { } func TestSqlStore_GetPeerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1525,7 +1557,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) { } func TestSqlStore_GetPeersByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1567,7 +1599,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { } func TestSqlStore_GetPostureChecksByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1613,7 +1645,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1656,7 +1688,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { } func TestSqlStore_SavePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1697,7 +1729,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) { } func TestSqlStore_DeletePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1744,7 +1776,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { } func TestSqlStore_GetPolicyByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1790,23 +1822,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) { } func TestSqlStore_CreatePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - policy := &Policy{ + policy := &types.Policy{ ID: "policy-id", AccountID: accountID, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1820,7 +1852,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) { } func TestSqlStore_SavePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1843,7 +1875,7 @@ func TestSqlStore_SavePolicy(t *testing.T) { } func TestSqlStore_DeletePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1859,7 +1891,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) { } func TestSqlStore_GetDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1903,7 +1935,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) { } func TestSqlStore_SaveDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1922,7 +1954,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { } func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1959,7 +1991,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { } func TestSqlStore_GetNameServerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2005,7 +2037,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) { } func TestSqlStore_SaveNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2037,7 +2069,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) { } func TestSqlStore_DeleteNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2051,3 +2083,443 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[nbroute.ID]*nbroute.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + }, + } + + if err := addAllGroup(acc); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} + +// addAllGroup to account object if it doesn't exist +func addAllGroup(account *types.Account) error { + if len(account.Groups) == 0 { + allGroup := &types.Group{ + ID: xid.New().String(), + Name: "All", + Issued: types.GroupIssuedAPI, + } + for _, peer := range account.Peers { + allGroup.Peers = append(allGroup.Peers, peer.ID) + } + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} + + id := xid.New().String() + + defaultPolicy := &types.Policy{ + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + + account.Policies = []*types.Policy{defaultPolicy} + } + return nil +} + +func TestSqlStore_GetAccountNetworks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve networks by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + + { + name: "retrieve networks by non-existing account ID", + accountID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, networks, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkID string + expectError bool + }{ + { + name: "retrieve existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectError: false, + }, + { + name: "retrieve non-existing network ID", + networkID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty ID", + networkID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, network) + } else { + require.NoError(t, err) + require.NotNil(t, network) + require.Equal(t, tt.networkID, network.ID) + } + }) + } +} + +func TestSqlStore_SaveNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + network := &networkTypes.Network{ + ID: "net-id", + AccountID: accountID, + Name: "net", + } + + err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + require.NoError(t, err) + + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + require.NoError(t, err) + require.Equal(t, network, savedNet) +} + +func TestSqlStore_DeleteNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + require.NoError(t, err) + + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, network) +} + +func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve routers by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve routers by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, routers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkRouterByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkRouterID string + expectError bool + }{ + { + name: "retrieve existing network router ID", + networkRouterID: "ctc20ji7qv9ck2sebc80", + expectError: false, + }, + { + name: "retrieve non-existing network router ID", + networkRouterID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty router ID", + networkRouterID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, tt.networkRouterID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, networkRouter) + } else { + require.NoError(t, err) + require.NotNil(t, networkRouter) + require.Equal(t, tt.networkRouterID, networkRouter.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0) + require.NoError(t, err) + + err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) + require.NoError(t, err) + + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, netRouter.ID) + require.NoError(t, err) + require.Equal(t, netRouter, savedNetRouter) +} + +func TestSqlStore_DeleteNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netRouterID := "ctc20ji7qv9ck2sebc80" + + err = store.DeleteNetworkRouter(context.Background(), LockingStrengthUpdate, accountID, netRouterID) + require.NoError(t, err) + + netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netRouterID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netRouter) +} + +func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve resources by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve resources by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, netResources, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkResourceByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + netResourceID string + expectError bool + }{ + { + name: "retrieve existing network resource ID", + netResourceID: "ctc4nci7qv9061u6ilfg", + expectError: false, + }, + { + name: "retrieve non-existing network resource ID", + netResourceID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty resource ID", + netResourceID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, tt.netResourceID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, netResource) + } else { + require.NoError(t, err) + require.NotNil(t, netResource) + require.Equal(t, tt.netResourceID, netResource.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com") + require.NoError(t, err) + + err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) + require.NoError(t, err) + + savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, netResource.ID) + require.NoError(t, err) + require.Equal(t, netResource, savedNetResource) +} + +func TestSqlStore_DeleteNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netResourceID := "ctc4nci7qv9061u6ilfg" + + err = store.DeleteNetworkResource(context.Background(), LockingStrengthUpdate, accountID, netResourceID) + require.NoError(t, err) + + netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netResourceID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netResource) +} diff --git a/management/server/store.go b/management/server/store/store.go similarity index 75% rename from management/server/store.go rename to management/server/store/store.go index b16ad8a1a..07fef6cfd 100644 --- a/management/server/store.go +++ b/management/server/store/store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -18,13 +18,15 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/migration" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/testutil" @@ -41,49 +43,49 @@ const ( ) type Store interface { - GetAllAccounts(ctx context.Context) []*Account - GetAccount(ctx context.Context, accountID string) (*Account, error) + GetAllAccounts(ctx context.Context) []*types.Account + GetAccount(ctx context.Context, accountID string) (*types.Account, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) - GetAccountByUser(ctx context.Context, userID string) (*Account, error) - GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) - GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) - GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) - SaveAccount(ctx context.Context, account *Account) error - DeleteAccount(ctx context.Context, account *Account) error + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) + SaveAccount(ctx context.Context, account *types.Account) error + DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) - SaveUsers(accountID string, users map[string]*User) error - SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error + GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) + SaveUsers(accountID string, users map[string]*types.User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error - GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) - CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -105,11 +107,11 @@ type Store interface { SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) @@ -122,7 +124,7 @@ type Store interface { GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error @@ -136,30 +138,47 @@ type Store interface { // Close should close the store persisting all unsaved data. Close(ctx context.Context) error - // GetStoreEngine should return StoreEngine of the current store implementation. + // GetStoreEngine should return Engine of the current store implementation. // This is also a method of metrics.DataSource interface. - GetStoreEngine() StoreEngine + GetStoreEngine() Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) + GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) + SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error + DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + + GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) + SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + + GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) + SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error + DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error } -type StoreEngine string +type Engine string const ( - FileStoreEngine StoreEngine = "jsonfile" - SqliteStoreEngine StoreEngine = "sqlite" - PostgresStoreEngine StoreEngine = "postgres" + FileStoreEngine Engine = "jsonfile" + SqliteStoreEngine Engine = "sqlite" + PostgresStoreEngine Engine = "postgres" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" ) -func getStoreEngineFromEnv() StoreEngine { +func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") if !ok { return "" } - value := StoreEngine(strings.ToLower(kind)) + value := Engine(strings.ToLower(kind)) if value == SqliteStoreEngine || value == PostgresStoreEngine { return value } @@ -171,7 +190,7 @@ func getStoreEngineFromEnv() StoreEngine { // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. // Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { +func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { @@ -197,7 +216,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { @@ -216,7 +235,7 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel } } -func checkFileStoreEngine(kind StoreEngine, dataDir string) error { +func checkFileStoreEngine(kind Engine, dataDir string) error { if kind == FileStoreEngine { storeFile := filepath.Join(dataDir, storeFileName) if util.FileExists(storeFile) { @@ -243,7 +262,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") + return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") @@ -258,7 +277,7 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, func(db *gorm.DB) error { - return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db) }, } } diff --git a/management/server/store_test.go b/management/server/store/store_test.go similarity index 93% rename from management/server/store_test.go rename to management/server/store/store_test.go index fc821670d..1d0026e3d 100644 --- a/management/server/store_test.go +++ b/management/server/store/store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } - -func newStore(t *testing.T) Store { - t.Helper() - - store := newSqliteStore(t) - - return store -} diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 168973cad..7f0c7b5a4 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -12,6 +12,9 @@ CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`type` text,`address` text,PRIMARY KEY (`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); CREATE INDEX `idx_peers_key` ON `peers`(`key`); @@ -24,6 +27,14 @@ CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); +CREATE INDEX `idx_network_routers_id` ON `network_routers`(`id`); +CREATE INDEX `idx_network_routers_account_id` ON `network_routers`(`account_id`); +CREATE INDEX `idx_network_routers_network_id` ON `network_routers`(`network_id`); +CREATE INDEX `idx_network_resources_account_id` ON `network_resources`(`account_id`); +CREATE INDEX `idx_network_resources_network_id` ON `network_resources`(`network_id`); +CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`); +CREATE INDEX `idx_networks_id` ON `networks`(`id`); +CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); @@ -34,3 +45,6 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); +INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); +INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); +INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); diff --git a/management/server/types/account.go b/management/server/types/account.go new file mode 100644 index 000000000..281c8ea63 --- /dev/null +++ b/management/server/types/account.go @@ -0,0 +1,1181 @@ +package types + +import ( + "context" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/route" +) + +const ( + defaultTTL = 300 + DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" +) + +type LookupMap map[string]struct{} + +// Account represents a unique account of the system +type Account struct { + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + + // User.Id it was created by + CreatedBy string + CreatedAt time.Time + Domain string `gorm:"index"` + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*nbpeer.Peer `gorm:"-"` + PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[route.ID]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` + PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` + // Settings is a dictionary of Account settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` + + Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` + NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` + NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` +} + +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + +// Subclass used in gorm to only load settings and not whole account +type AccountSettings struct { + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} + +// GetRoutesToSync returns the enabled routes for the peer ID and the routes +// from the ACL peers that have distribution groups associated with the peer ID. +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + groupListMap := a.GetPeerGroups(peerID) + for _, peer := range aclPeers { + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) + filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership +func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map +func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +// If the given is not a routing peer, then the lists are empty. +func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + for _, r := range a.Routes { + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID { + takeRoute(r.Copy(), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { + var routes []*route.Route + for _, r := range a.Routes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes +} + +// GetGroup returns a group by ID if exists, nil otherwise +func (a *Account) GetGroup(groupID string) *Group { + return a.Groups[groupID] +} + +// GetPeerNetworkMap returns the networkmap for the given peer ID. +func (a *Account) GetPeerNetworkMap( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + + peer := a.Peers[peerID] + if peer == nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + // exclude expired peers + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + nm := &NetworkMap{ + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } + + if metrics != nil { + objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", + a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) + } + } + + return nm +} + +func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { + groupList := account.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +// peerIsNameserver returns true if the peer is a nameserver for a nsGroup +func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peer.IP.Equal(ns.IP.AsSlice()) { + return true + } + } + return false +} + +func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { + for _, peer := range account.Peers { + label, err := GetPeerHostLabel(peer.Name, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + label, err = GetPeerHostLabel(peer.Meta.Hostname, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + continue + } + } + peer.DNSLabel = label + peerLabels[label] = struct{}{} + } +} + +func GetPeerHostLabel(name string, peerLabels LookupMap) (string, error) { + label, err := nbdns.GetParsedDomainLabel(name) + if err != nil { + return "", err + } + + uniqueLabel := getUniqueHostLabel(label, peerLabels) + if uniqueLabel == "" { + return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) + } + return uniqueLabel, nil +} + +// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 +func getUniqueHostLabel(name string, peerLabels LookupMap) string { + _, found := peerLabels[name] + if !found { + return name + } + for i := 1; i < 1000; i++ { + nameWithSuffix := name + "-" + strconv.Itoa(i) + _, found = peerLabels[nameWithSuffix] + if !found { + return nameWithSuffix + } + } + return "" +} + +func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { + var merr *multierror.Error + + if dnsDomain == "" { + log.WithContext(ctx).Error("no dns domain is set, returning empty zone") + return nbdns.CustomZone{} + } + + customZone := nbdns.CustomZone{ + Domain: dns.Fqdn(dnsDomain), + Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), + } + + domainSuffix := "." + dnsDomain + + var sb strings.Builder + for _, peer := range a.Peers { + if peer.DNSLabel == "" { + merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) + continue + } + + sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) + sb.WriteString(peer.DNSLabel) + sb.WriteString(domainSuffix) + + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: sb.String(), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IP.String(), + }) + + sb.Reset() + } + + go func() { + if merr != nil { + log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) + } + }() + + return customZone +} + +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetPeers returns a list of all Account peers +func (a *Account) GetPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.Peers { + peers = append(peers, peer) + } + return peers +} + +// UpdateSettings saves new account settings +func (a *Account) UpdateSettings(update *Settings) *Account { + a.Settings = update.Copy() + return a +} + +// UpdatePeer saves new or replaces existing peer +func (a *Account) UpdatePeer(update *nbpeer.Peer) { + a.Peers[update.ID] = update +} + +// DeletePeer deletes peer from the account cleaning up all the references +func (a *Account) DeletePeer(peerID string) { + // delete peer from groups + for _, g := range a.Groups { + for i, pk := range g.Peers { + if pk == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + break + } + } + } + + for _, r := range a.Routes { + if r.Peer == peerID { + r.Enabled = false + r.Peer = "" + } + } + + delete(a.Peers, peerID) + a.Network.IncSerial() +} + +// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. +// It will return an object copy of the peer. +func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { + for _, peer := range a.Peers { + if peer.Key == peerPubKey { + return peer.Copy(), nil + } + } + + return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) +} + +// FindUserPeers returns a list of peers that user owns (created) +func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.UserID == userID { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// FindUser looks for a given user in the Account or returns error if user wasn't found. +func (a *Account) FindUser(userID string) (*User, error) { + user := a.Users[userID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user %s not found", userID) + } + + return user, nil +} + +// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. +func (a *Account) FindGroupByName(groupName string) (*Group, error) { + for _, group := range a.Groups { + if group.Name == groupName { + return group, nil + } + } + return nil, status.Errorf(status.NotFound, "group %s not found", groupName) +} + +// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. +func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { + key := a.SetupKeys[setupKey] + if key == nil { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return key, nil +} + +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + +func (a *Account) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := a.GetPeerGroups(peerID) + enabled := true + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (a *Account) GetPeerGroups(peerID string) LookupMap { + groupList := make(LookupMap) + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + groupList[groupID] = struct{}{} + break + } + } + } + return groupList +} + +func (a *Account) GetTakenIPs() []net.IP { + var takenIps []net.IP + for _, existingPeer := range a.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps +} + +func (a *Account) GetPeerDNSLabels() LookupMap { + existingLabels := make(LookupMap) + for _, peer := range a.Peers { + if peer.DNSLabel != "" { + existingLabels[peer.DNSLabel] = struct{}{} + } + } + return existingLabels +} + +func (a *Account) Copy() *Account { + peers := map[string]*nbpeer.Peer{} + for id, peer := range a.Peers { + peers[id] = peer.Copy() + } + + users := map[string]*User{} + for id, user := range a.Users { + users[id] = user.Copy() + } + + setupKeys := map[string]*SetupKey{} + for id, key := range a.SetupKeys { + setupKeys[id] = key.Copy() + } + + groups := map[string]*Group{} + for id, group := range a.Groups { + groups[id] = group.Copy() + } + + policies := []*Policy{} + for _, policy := range a.Policies { + policies = append(policies, policy.Copy()) + } + + routes := map[route.ID]*route.Route{} + for id, r := range a.Routes { + routes[id] = r.Copy() + } + + nsGroups := map[string]*nbdns.NameServerGroup{} + for id, nsGroup := range a.NameServerGroups { + nsGroups[id] = nsGroup.Copy() + } + + dnsSettings := a.DNSSettings.Copy() + + var settings *Settings + if a.Settings != nil { + settings = a.Settings.Copy() + } + + postureChecks := []*posture.Checks{} + for _, postureCheck := range a.PostureChecks { + postureChecks = append(postureChecks, postureCheck.Copy()) + } + + nets := []*networkTypes.Network{} + for _, network := range a.Networks { + nets = append(nets, network.Copy()) + } + + networkRouters := []*routerTypes.NetworkRouter{} + for _, router := range a.NetworkRouters { + networkRouters = append(networkRouters, router.Copy()) + } + + networkResources := []*resourceTypes.NetworkResource{} + for _, resource := range a.NetworkResources { + networkResources = append(networkResources, resource.Copy()) + } + + return &Account{ + Id: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, + SetupKeys: setupKeys, + Network: a.Network.Copy(), + Peers: peers, + Users: users, + Groups: groups, + Policies: policies, + Routes: routes, + NameServerGroups: nsGroups, + DNSSettings: dnsSettings, + PostureChecks: postureChecks, + Settings: settings, + Networks: nets, + NetworkRouters: networkRouters, + NetworkResources: networkResources, + } +} + +func (a *Account) GetGroupAll() (*Group, error) { + for _, g := range a.Groups { + if g.Name == "All" { + return g, nil + } + } + return nil, fmt.Errorf("no group ALL found") +} + +// GetPeer looks up a Peer by ID +func (a *Account) GetPeer(peerID string) *nbpeer.Peer { + return a.Peers[peerID] +} + +// UserGroupsAddToPeers adds groups to all peers of user +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + userPeers := make(map[string]struct{}) + for pid, peer := range a.Peers { + if peer.UserID == userID { + userPeers[pid] = struct{}{} + } + } + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok { + continue + } + + oldPeers := group.Peers + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeers { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupUpdates[gid] = util.Difference(group.Peers, oldPeers) + } + + return groupUpdates +} + +// UserGroupsRemoveFromPeers removes groups from all peers of user +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok || group.Name == "All" { + continue + } + + oldPeers := group.Peers + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok { + continue + } + if peer.UserID != userID { + update = append(update, pid) + } + } + group.Peers = update + groupUpdates[gid] = util.Difference(oldPeers, group.Peers) + } + + return groupUpdates +} + +// GetPeerConnectionResources for a given peer +// +// This function returns the list of peers and firewall rules that are applicable to a given peer. +func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + } + } + + return getAccumulatedResources() +} + +// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls +// +// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. +// It safe to call the generator function multiple times for same peer and different rules no duplicates will be +// generated. The accumulator function returns the result of all the generator calls. +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + all, err := a.GetGroupAll() + if err != nil { + log.WithContext(ctx).Errorf("failed to get group all: %v", err) + all = &Group{} + } + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + isAll := (len(all.Peers) - 1) == len(groupPeers) + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = "0.0.0.0" + } + + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 { + rules = append(rules, &fr) + continue + } + + for _, port := range rule.Ports { + pr := fr // clone rule and add set new port + pr.Port = port + rules = append(rules, &pr) + } + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +// getAllPeersFromGroups for given peer ID and list of groups +// +// Returns a list of peers from specified groups that pass specified posture checks +// and a boolean indicating if the supplied peer ID exists within these groups. +// +// Important: Posture checks are applicable only to source group peers, +// for destination group peers, call this method with an empty list of sourcePostureChecksIDs +func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { + peerInGroups := false + filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) + for _, g := range groups { + group, ok := a.Groups[g] + if !ok { + continue + } + + for _, p := range group.Peers { + peer, ok := a.Peers[p] + if !ok || peer == nil { + continue + } + + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) + } + } + return filteredPeers, peerInGroups +} + +// validatePostureChecksOnPeer validates the posture checks on a peer +func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { + peer, ok := a.Peers[peerID] + if !ok && peer == nil { + return false + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, err := check.Check(ctx, *peer) + if err != nil { + log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) + } + if !isValid { + return false + } + } + } + return true +} + +func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { + for _, postureChecks := range a.PostureChecks { + if postureChecks.ID == postureChecksID { + return postureChecks + } + } + return nil +} + +// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} diff --git a/management/server/types/dns_settings.go b/management/server/types/dns_settings.go new file mode 100644 index 000000000..1d33bb9fb --- /dev/null +++ b/management/server/types/dns_settings.go @@ -0,0 +1,16 @@ +package types + +// DNSSettings defines dns settings at the account level +type DNSSettings struct { + // DisabledManagementGroups groups whose DNS management is disabled + DisabledManagementGroups []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the DNS settings +func (d DNSSettings) Copy() DNSSettings { + settings := DNSSettings{ + DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), + } + copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) + return settings +} diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go new file mode 100644 index 000000000..1c9b6ca5b --- /dev/null +++ b/management/server/types/firewall_rule.go @@ -0,0 +1,129 @@ +package types + +import ( + "context" + "fmt" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" +) + +const ( + FirewallRuleDirectionIN = 0 + FirewallRuleDirectionOUT = 1 +) + +// FirewallRule is a rule of the firewall. +type FirewallRule struct { + // PeerIP of the peer + PeerIP string + + // Direction of the traffic + Direction int + + // Action of the traffic + Action string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port string +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. +func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} diff --git a/management/server/group/group.go b/management/server/types/group.go similarity index 67% rename from management/server/group/group.go rename to management/server/types/group.go index 24c60d3ce..7ba4b8656 100644 --- a/management/server/group/group.go +++ b/management/server/types/group.go @@ -1,6 +1,8 @@ -package group +package types -import "github.com/netbirdio/netbird/management/server/integration_reference" +import ( + "github.com/netbirdio/netbird/management/server/integration_reference" +) const ( GroupIssuedAPI = "api" @@ -25,6 +27,9 @@ type Group struct { // Peers list of the group Peers []string `gorm:"serializer:json"` + // Resources contains a list of resources in that group + Resources []Resource `gorm:"serializer:json"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } @@ -39,9 +44,11 @@ func (g *Group) Copy() *Group { Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.Resources, g.Resources) return group } @@ -81,3 +88,26 @@ func (g *Group) RemovePeer(peerID string) bool { } return false } + +// AddResource adds resource to Resources if not present, returning true if added. +func (g *Group) AddResource(resource Resource) bool { + for _, item := range g.Resources { + if item == resource { + return false + } + } + + g.Resources = append(g.Resources, resource) + return true +} + +// RemoveResource removes resource from Resources if present, returning true if removed. +func (g *Group) RemoveResource(resource Resource) bool { + for i, item := range g.Resources { + if item == resource { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/types/group_test.go similarity index 99% rename from management/server/group/group_test.go rename to management/server/types/group_test.go index cb002f8d9..12107c603 100644 --- a/management/server/group/group_test.go +++ b/management/server/types/group_test.go @@ -1,4 +1,4 @@ -package group +package types import ( "testing" diff --git a/management/server/network.go b/management/server/types/network.go similarity index 96% rename from management/server/network.go rename to management/server/types/network.go index a5b188b46..d1fccd149 100644 --- a/management/server/network.go +++ b/management/server/types/network.go @@ -1,4 +1,4 @@ -package server +package types import ( "math/rand" @@ -43,7 +43,7 @@ type Network struct { // Used to synchronize state to the client apps. Serial uint64 - mu sync.Mutex `json:"-" gorm:"-"` + Mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 @@ -66,15 +66,15 @@ func NewNetwork() *Network { // IncSerial increments Serial by 1 reflecting that the network state has been changed func (n *Network) IncSerial() { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() n.Serial++ } // CurrentSerial returns the Network.Serial of the network (latest state id) func (n *Network) CurrentSerial() uint64 { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() return n.Serial } diff --git a/management/server/network_test.go b/management/server/types/network_test.go similarity index 98% rename from management/server/network_test.go rename to management/server/types/network_test.go index b067c4991..d0b0894d4 100644 --- a/management/server/network_test.go +++ b/management/server/types/network_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "net" diff --git a/management/server/personal_access_token.go b/management/server/types/personal_access_token.go similarity index 99% rename from management/server/personal_access_token.go rename to management/server/types/personal_access_token.go index f46666112..1bf225856 100644 --- a/management/server/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/personal_access_token_test.go b/management/server/types/personal_access_token_test.go similarity index 98% rename from management/server/personal_access_token_test.go rename to management/server/types/personal_access_token_test.go index 311ffd9cf..ac3377151 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/types/personal_access_token_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/types/policy.go b/management/server/types/policy.go new file mode 100644 index 000000000..c0d84e6e0 --- /dev/null +++ b/management/server/types/policy.go @@ -0,0 +1,116 @@ +package types + +const ( + // PolicyTrafficActionAccept indicates that the traffic is accepted + PolicyTrafficActionAccept = PolicyTrafficActionType("accept") + // PolicyTrafficActionDrop indicates that the traffic is dropped + PolicyTrafficActionDrop = PolicyTrafficActionType("drop") +) + +const ( + // PolicyRuleProtocolALL type of traffic + PolicyRuleProtocolALL = PolicyRuleProtocolType("all") + // PolicyRuleProtocolTCP type of traffic + PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") + // PolicyRuleProtocolUDP type of traffic + PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") + // PolicyRuleProtocolICMP type of traffic + PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") +) + +const ( + // PolicyRuleFlowDirect allows traffic from source to destination + PolicyRuleFlowDirect = PolicyRuleDirection("direct") + // PolicyRuleFlowBidirect allows traffic to both directions + PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") +) + +const ( + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = "Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + +// PolicyUpdateOperation operation object with type and values to be applied +type PolicyUpdateOperation struct { + Type PolicyUpdateOperationType + Values []string +} + +// Policy of the Rego query +type Policy struct { + // ID of the policy' + ID string `gorm:"primaryKey"` + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name of the Policy + Name string + + // Description of the policy visible in the UI + Description string + + // Enabled status of the policy + Enabled bool + + // Rules of the policy + Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` + + // SourcePostureChecks are ID references to Posture checks for policy source groups + SourcePostureChecks []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the policy. +func (p *Policy) Copy() *Policy { + c := &Policy{ + ID: p.ID, + AccountID: p.AccountID, + Name: p.Name, + Description: p.Description, + Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), + SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), + } + for i, r := range p.Rules { + c.Rules[i] = r.Copy() + } + copy(c.SourcePostureChecks, p.SourcePostureChecks) + return c +} + +// EventMeta returns activity event meta related to this policy +func (p *Policy) EventMeta() map[string]any { + return map[string]any{"name": p.Name} +} + +// UpgradeAndFix different version of policies to latest version +func (p *Policy) UpgradeAndFix() { + for _, r := range p.Rules { + // start migrate from version v0.20.3 + if r.Protocol == "" { + r.Protocol = PolicyRuleProtocolALL + } + if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { + r.Bidirectional = true + } + // -- v0.20.4 + } +} + +// RuleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) RuleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go new file mode 100644 index 000000000..bd9a99292 --- /dev/null +++ b/management/server/types/policyrule.go @@ -0,0 +1,87 @@ +package types + +// PolicyUpdateOperationType operation type +type PolicyUpdateOperationType int + +// PolicyTrafficActionType action type for the firewall +type PolicyTrafficActionType string + +// PolicyRuleProtocolType type of traffic +type PolicyRuleProtocolType string + +// PolicyRuleDirection direction of traffic +type PolicyRuleDirection string + +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + +// PolicyRule is the metadata of the policy +type PolicyRule struct { + // ID of the policy rule + ID string `gorm:"primaryKey"` + + // PolicyID is a reference to Policy that this object belongs + PolicyID string `json:"-" gorm:"index"` + + // Name of the rule visible in the UI + Name string + + // Description of the rule visible in the UI + Description string + + // Enabled status of rule in the system + Enabled bool + + // Action policy accept or drops packets + Action PolicyTrafficActionType + + // Destinations policy destination groups + Destinations []string `gorm:"serializer:json"` + + // DestinationResource policy destination resource that the rule is applied to + DestinationResource Resource `gorm:"serializer:json"` + + // Sources policy source groups + Sources []string `gorm:"serializer:json"` + + // SourceResource policy source resource that the rule is applied to + SourceResource Resource `gorm:"serializer:json"` + + // Bidirectional define if the rule is applicable in both directions, sources, and destinations + Bidirectional bool + + // Protocol type of the traffic + Protocol PolicyRuleProtocolType + + // Ports or it ranges list + Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` +} + +// Copy returns a copy of a policy rule +func (pm *PolicyRule) Copy() *PolicyRule { + rule := &PolicyRule{ + ID: pm.ID, + PolicyID: pm.PolicyID, + Name: pm.Name, + Description: pm.Description, + Enabled: pm.Enabled, + Action: pm.Action, + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), + Bidirectional: pm.Bidirectional, + Protocol: pm.Protocol, + Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), + } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) + return rule +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go new file mode 100644 index 000000000..820872f20 --- /dev/null +++ b/management/server/types/resource.go @@ -0,0 +1,30 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Resource struct { + ID string + Type string +} + +func (r *Resource) ToAPIResponse() *api.Resource { + if r.ID == "" && r.Type == "" { + return nil + } + + return &api.Resource{ + Id: r.ID, + Type: api.ResourceType(r.Type), + } +} + +func (r *Resource) FromAPIRequest(req *api.Resource) { + if req == nil { + return + } + + r.ID = req.Id + r.Type = string(req.Type) +} diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go new file mode 100644 index 000000000..73d49d01d --- /dev/null +++ b/management/server/types/route_firewall_rule.go @@ -0,0 +1,25 @@ +package types + +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. + SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} diff --git a/management/server/types/settings.go b/management/server/types/settings.go new file mode 100644 index 000000000..0c1a3ecab --- /dev/null +++ b/management/server/types/settings.go @@ -0,0 +1,63 @@ +package types + +import ( + "time" + + "github.com/netbirdio/netbird/management/server/account" +) + +// Settings represents Account settings structure that can be modified via API and Dashboard +type Settings struct { + // PeerLoginExpirationEnabled globally enables or disables peer login expiration + PeerLoginExpirationEnabled bool + + // PeerLoginExpiration is a setting that indicates when peer login expires. + // Applies to all peers that have Peer.LoginExpirationEnabled set to true. + PeerLoginExpiration time.Duration + + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer + GroupsPropagationEnabled bool + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. + JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string + + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + + // Extra is a dictionary of Account settings + Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` +} + +// Copy copies the Settings struct +func (s *Settings) Copy() *Settings { + settings := &Settings{ + PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, + PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, + GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, + } + if s.Extra != nil { + settings.Extra = s.Extra.Copy() + } + return settings +} diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go new file mode 100644 index 000000000..a5cf346a0 --- /dev/null +++ b/management/server/types/setupkey.go @@ -0,0 +1,181 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/fnv" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" +) + +const ( + // SetupKeyReusable is a multi-use key (can be used for multiple machines) + SetupKeyReusable SetupKeyType = "reusable" + // SetupKeyOneOff is a single use key (can be used only once) + SetupKeyOneOff SetupKeyType = "one-off" + // DefaultSetupKeyDuration = 1 month + DefaultSetupKeyDuration = 24 * 30 * time.Hour + // DefaultSetupKeyName is a default name of the default setup key + DefaultSetupKeyName = "Default key" + // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key + SetupKeyUnlimitedUsage = 0 +) + +// SetupKeyType is the type of setup key +type SetupKeyType string + +// SetupKey represents a pre-authorized key used to register machines (peers) +type SetupKey struct { + Id string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Key string + KeySecret string + Name string + Type SetupKeyType + CreatedAt time.Time + ExpiresAt time.Time + UpdatedAt time.Time `gorm:"autoUpdateTime:false"` + // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) + Revoked bool + // UsedTimes indicates how many times the key was used + UsedTimes int + // LastUsed last time the key was used for peer registration + LastUsed time.Time + // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register + AutoGroups []string `gorm:"serializer:json"` + // UsageLimit indicates the number of times this key can be used to enroll a machine. + // The value of 0 indicates the unlimited usage. + UsageLimit int + // Ephemeral indicate if the peers will be ephemeral or not + Ephemeral bool +} + +// Copy copies SetupKey to a new object +func (key *SetupKey) Copy() *SetupKey { + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + return &SetupKey{ + Id: key.Id, + AccountID: key.AccountID, + Key: key.Key, + KeySecret: key.KeySecret, + Name: key.Name, + Type: key.Type, + CreatedAt: key.CreatedAt, + ExpiresAt: key.ExpiresAt, + UpdatedAt: key.UpdatedAt, + Revoked: key.Revoked, + UsedTimes: key.UsedTimes, + LastUsed: key.LastUsed, + AutoGroups: autoGroups, + UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, + } +} + +// EventMeta returns activity event meta related to the setup key +func (key *SetupKey) EventMeta() map[string]any { + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} +} + +// HiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func HiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} + +// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now +func (key *SetupKey) IncrementUsage() *SetupKey { + c := key.Copy() + c.UsedTimes++ + c.LastUsed = time.Now().UTC() + return c +} + +// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to +func (key *SetupKey) IsValid() bool { + return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() +} + +// IsRevoked if key was revoked +func (key *SetupKey) IsRevoked() bool { + return key.Revoked +} + +// IsExpired if key was expired +func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } + return time.Now().After(key.ExpiresAt) +} + +// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. +func (key *SetupKey) IsOverUsed() bool { + limit := key.UsageLimit + if key.Type == SetupKeyOneOff { + limit = 1 + } + return limit > 0 && key.UsedTimes >= limit +} + +// GenerateSetupKey generates a new setup key +func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, + usageLimit int, ephemeral bool) (*SetupKey, string) { + key := strings.ToUpper(uuid.New().String()) + limit := usageLimit + if t == SetupKeyOneOff { + limit = 1 + } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + return &SetupKey{ + Id: strconv.Itoa(int(Hash(key))), + Key: encodedHashedKey, + KeySecret: HiddenKey(key, 4), + Name: name, + Type: t, + CreatedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + UpdatedAt: time.Now().UTC(), + Revoked: false, + UsedTimes: 0, + AutoGroups: autoGroups, + UsageLimit: limit, + Ephemeral: ephemeral, + }, key +} + +// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration +func GenerateDefaultSetupKey() (*SetupKey, string) { + return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, + SetupKeyUnlimitedUsage, false) +} + +func Hash(s string) uint32 { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + panic(err) + } + return h.Sum32() +} diff --git a/management/server/types/user.go b/management/server/types/user.go new file mode 100644 index 000000000..5f1b71792 --- /dev/null +++ b/management/server/types/user.go @@ -0,0 +1,231 @@ +package types + +import ( + "fmt" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" +) + +const ( + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" +) + +// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown +func StrRoleToUserRole(strRole string) UserRole { + switch strings.ToLower(strRole) { + case "owner": + return UserRoleOwner + case "admin": + return UserRoleAdmin + case "user": + return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin + default: + return UserRoleUnknown + } +} + +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User +type UserRole string + +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` +} + +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + +// User represents a user of the system +type User struct { + Id string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time + // CreatedAt records the time the user was created + CreatedAt time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { + return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +} + +// HasAdminPower returns true if the user has admin or owner roles, false otherwise +func (u *User) HasAdminPower() bool { + return u.Role == UserRoleAdmin || u.Role == UserRoleOwner +} + +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + +// ToUserInfo converts a User object to a UserInfo object. +func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { + autoGroups := u.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + + if userData == nil { + return &UserInfo{ + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil + } + if userData.ID != u.Id { + return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) + } + + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + + return &UserInfo{ + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil +} + +// Copy the user +func (u *User) Copy() *User { + autoGroups := make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + pats[k] = v.Copy() + } + return &User{ + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + NonDeletable: u.NonDeletable, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + CreatedAt: u.CreatedAt, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, + } +} + +// NewUser creates a new user +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { + return &User{ + Id: id, + Role: role, + IsServiceUser: isServiceUser, + NonDeletable: nonDeletable, + ServiceUserName: serviceUserName, + AutoGroups: autoGroups, + Issued: issued, + CreatedAt: time.Now().UTC(), + } +} + +// NewRegularUser creates a new user with role UserRoleUser +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +} + +// NewAdminUser creates a new user with role UserRoleAdmin +func NewAdminUser(id string) *User { + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +} + +// NewOwnerUser creates a new user with role UserRoleOwner +func NewOwnerUser(id string) *User { + return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index d338b84b1..de7dd57df 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -9,13 +9,14 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const channelBufferSize = 100 type UpdateMessage struct { Update *proto.SyncResponse - NetworkMap *NetworkMap + NetworkMap *types.NetworkMap } type PeersUpdateManager struct { diff --git a/management/server/user.go b/management/server/user.go index edb5e6fd3..9fc2464de 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -13,217 +13,17 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" - UserRoleBillingAdmin UserRole = "billing_admin" - - UserStatusActive UserStatus = "active" - UserStatusDisabled UserStatus = "disabled" - UserStatusInvited UserStatus = "invited" - - UserIssuedAPI = "api" - UserIssuedIntegration = "integration" -) - -// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown -func StrRoleToUserRole(strRole string) UserRole { - switch strings.ToLower(strRole) { - case "owner": - return UserRoleOwner - case "admin": - return UserRoleAdmin - case "user": - return UserRoleUser - case "billing_admin": - return UserRoleBillingAdmin - default: - return UserRoleUnknown - } -} - -// UserStatus is the status of a User -type UserStatus string - -// UserRole is the role of a User -type UserRole string - -// User represents a user of the system -type User struct { - Id string `gorm:"primaryKey"` - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Role UserRole - IsServiceUser bool - // NonDeletable indicates whether the service user can be deleted - NonDeletable bool - // ServiceUserName is only set if IsServiceUser is true - ServiceUserName string - // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string `gorm:"serializer:json"` - PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` - // Blocked indicates whether the user is blocked. Blocked users can't use the system. - Blocked bool - // LastLogin is the last time the user logged in to IdP - LastLogin time.Time - // CreatedAt records the time the user was created - CreatedAt time.Time - - // Issued of the user - Issued string `gorm:"default:api"` - - IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// IsBlocked returns true if the user is blocked, false otherwise -func (u *User) IsBlocked() bool { - return u.Blocked -} - -func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { - return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() -} - -// HasAdminPower returns true if the user has admin or owner roles, false otherwise -func (u *User) HasAdminPower() bool { - return u.Role == UserRoleAdmin || u.Role == UserRoleOwner -} - -// IsAdminOrServiceUser checks if the user has admin power or is a service user. -func (u *User) IsAdminOrServiceUser() bool { - return u.HasAdminPower() || u.IsServiceUser -} - -// IsRegularUser checks if the user is a regular user. -func (u *User) IsRegularUser() bool { - return !u.HasAdminPower() && !u.IsServiceUser -} - -// ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { - autoGroups := u.AutoGroups - if autoGroups == nil { - autoGroups = []string{} - } - - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - - if userData == nil { - return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil - } - if userData.ID != u.Id { - return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) - } - - userStatus := UserStatusActive - if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { - userStatus = UserStatusInvited - } - - return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil -} - -// Copy the user -func (u *User) Copy() *User { - autoGroups := make([]string, len(u.AutoGroups)) - copy(autoGroups, u.AutoGroups) - pats := make(map[string]*PersonalAccessToken, len(u.PATs)) - for k, v := range u.PATs { - pats[k] = v.Copy() - } - return &User{ - Id: u.Id, - AccountID: u.AccountID, - Role: u.Role, - AutoGroups: autoGroups, - IsServiceUser: u.IsServiceUser, - NonDeletable: u.NonDeletable, - ServiceUserName: u.ServiceUserName, - PATs: pats, - Blocked: u.Blocked, - LastLogin: u.LastLogin, - CreatedAt: u.CreatedAt, - Issued: u.Issued, - IntegrationReference: u.IntegrationReference, - } -} - -// NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { - return &User{ - Id: id, - Role: role, - IsServiceUser: isServiceUser, - NonDeletable: nonDeletable, - ServiceUserName: serviceUserName, - AutoGroups: autoGroups, - Issued: issued, - CreatedAt: time.Now().UTC(), - } -} - -// NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) -} - -// NewAdminUser creates a new user with role UserRoleAdmin -func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) -} - -// NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) -} - // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -240,12 +40,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users") } - if role == UserRoleOwner { + if role == types.UserRoleOwner { return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role") } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) + newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) log.WithContext(ctx).Debugf("New User: %v", newUser) account.Users[newUserID] = newUser @@ -257,29 +57,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI meta := map[string]any{"name": newUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) - return &UserInfo{ + return &types.UserInfo{ ID: newUser.Id, Email: "", Name: newUser.ServiceUserName, Role: string(newUser.Role), AutoGroups: newUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: true, LastLogin: time.Time{}, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nil } // CreateUser creates a new user under the given account. Effectively this is a user invite. -func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -291,14 +91,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, fmt.Errorf("provided user update is nil") } - invitedRole := StrRoleToUserRole(invite.Role) + invitedRole := types.StrRoleToUserRole(invite.Role) switch { case invite.Name == "": return nil, status.Errorf(status.InvalidArgument, "name can't be empty") case invite.Email == "": return nil, status.Errorf(status.InvalidArgument, "email can't be empty") - case invitedRole == UserRoleOwner: + case invitedRole == types.UserRoleOwner: return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role") default: } @@ -348,7 +148,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - newUser := &User{ + newUser := &types.User{ Id: idpUser.ID, Role: invitedRole, AutoGroups: invite.AutoGroups, @@ -373,19 +173,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } -func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { - return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { + return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -409,7 +209,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. -func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -418,7 +218,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return nil, err } - users := make([]*User, 0, len(account.Users)) + users := make([]*types.User, 0, len(account.Users)) for _, item := range account.Users { users = append(users, item) } @@ -426,7 +226,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return users, nil } -func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) { meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) delete(account.Users, targetUser.Id) @@ -458,12 +258,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.NotFound, "target user not found") } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role") } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only integration service user can delete this user") } @@ -480,7 +280,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) } -func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error { meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err @@ -500,7 +300,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { return false, status.Errorf(status.Internal, "failed to find user peers") @@ -560,7 +360,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -591,7 +391,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) + pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } @@ -660,13 +460,13 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -685,13 +485,13 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -700,7 +500,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG)) for _, pat := range targetUser.PATsG { pats = append(pats, pat.Copy()) } @@ -709,13 +509,13 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. -func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) { return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } @@ -723,7 +523,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err } @@ -738,7 +538,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i // SaveOrAddUsers updates existing users or adds new users to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if len(updates) == 0 { return nil, nil //nolint:nilnil } @@ -757,7 +557,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") } - updatedUsers := make([]*UserInfo, 0, len(updates)) + updatedUsers := make([]*types.UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer userIDs []string @@ -808,7 +608,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, peerGroupsAdded := make(map[string][]string) peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) @@ -851,7 +651,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -880,11 +680,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() if newUser.AutoGroups != nil { - removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups) + addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups) removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) eventsToStore = append(eventsToStore, removedEvents...) @@ -895,7 +695,7 @@ func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { var eventsToStore []func() for _, g := range addedGroups { group := account.GetGroup(g) @@ -922,7 +722,7 @@ func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, ini return eventsToStore } -func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() for _, g := range removedGroups { group := account.GetGroup(g) @@ -952,10 +752,10 @@ func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, return eventsToStore } -func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { - if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { +func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool { + if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin + newInitiatorUser.Role = types.UserRoleAdmin account.Users[initiatorUser.Id] = newInitiatorUser return true } @@ -965,7 +765,7 @@ func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) { +func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) { if !isNil(am.idpManager) && !user.IsServiceUser { userData, err := am.lookupUserInCache(ctx, user.Id, account) if err != nil { @@ -977,23 +777,23 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc } // validateUserUpdate validates the update operation for a user. -func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } - if oldUser.IsServiceUser && update.Role == UserRoleOwner { + if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "can't update a service user with owner role") } @@ -1012,7 +812,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -1039,7 +839,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u userObj := account.Users[userID] - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { + if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner { account.Domain = lowerDomain err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -1052,7 +852,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. -func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) { account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -1068,7 +868,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun users := make(map[string]userLoggedInOnce, len(account.Users)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range account.Users { - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { @@ -1092,7 +892,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - userInfos := make([]*UserInfo, 0) + userInfos := make([]*types.UserInfo, 0) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { @@ -1116,7 +916,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } - var info *UserInfo + var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { @@ -1136,16 +936,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } } - info = &UserInfo{ + info = &types.UserInfo{ ID: localUser.Id, Email: "", Name: name, Role: string(localUser.Role), AutoGroups: localUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, + Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) @@ -1155,7 +955,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -1260,13 +1060,13 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID)) continue } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) continue } @@ -1301,7 +1101,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) @@ -1342,8 +1142,8 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return @@ -1376,7 +1176,7 @@ func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[strin } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -1393,7 +1193,7 @@ func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { // skip removing peers from group All if group.Name == "All" { return @@ -1419,7 +1219,7 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } // areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. -func areUsersLinkedToPeers(account *Account, userIDs []string) bool { +func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool { for _, peer := range account.Peers { if slices.Contains(userIDs, peer.UserID) { return true diff --git a/management/server/user_test.go b/management/server/user_test.go index 498017afa..75d88f9c8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,8 +10,11 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,11 +44,15 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -82,14 +89,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -104,14 +115,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -130,11 +145,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { } func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -149,11 +168,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -168,19 +191,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } func TestUser_DeletePAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -204,20 +231,24 @@ func TestUser_DeletePAT(t *testing.T) { } func TestUser_GetPAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -237,13 +268,17 @@ func TestUser_GetPAT(t *testing.T) { } func TestUser_GetAllPATs(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, @@ -254,7 +289,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -274,14 +309,14 @@ func TestUser_GetAllPATs(t *testing.T) { func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way - user := User{ + user := types.User{ Id: "userId", AccountID: "accountId", Role: "role", IsServiceUser: true, ServiceUserName: "servicename", AutoGroups: []string{"group1", "group2"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -340,11 +375,15 @@ func validateStruct(s interface{}) (err error) { } func TestUser_CreateServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -366,26 +405,30 @@ func TestUser_CreateServiceUser(t *testing.T) { assert.NotNil(t, account.Users[user.ID]) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) + assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } } func TestUser_CreateUser_ServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -395,7 +438,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -413,7 +456,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { assert.Equal(t, 2, len(account.Users)) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) @@ -423,11 +466,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { } func TestUser_CreateUser_RegularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -437,7 +484,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -448,11 +495,15 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { } func TestUser_InviteNewUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -495,7 +546,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -506,9 +557,9 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, - Role: string(UserRoleOwner), + Role: string(types.UserRoleOwner), Email: "test2@teste.com", IsServiceUser: false, AutoGroups: []string{"group1", "group2"}, @@ -520,13 +571,13 @@ func TestUser_InviteNewUser(t *testing.T) { func TestUser_DeleteUser_ServiceUser(t *testing.T) { tests := []struct { name string - serviceUser *User + serviceUser *types.User assertErrFunc assert.ErrorAssertionFunc assertErrMessage string }{ { name: "Can delete service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -535,7 +586,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { }, { name: "Cannot delete non-deletable service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -548,11 +599,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -580,11 +636,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } func TestUser_DeleteUser_SelfDelete(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -601,38 +661,42 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { } func TestUser_DeleteUser_regularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -683,60 +747,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } func TestUser_DeleteUser_RegularUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - account.Users["user6"] = &User{ + account.Users["user6"] = &types.User{ Id: "user6", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user7"] = &User{ + account.Users["user7"] = &types.User{ Id: "user7", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user8"] = &User{ + account.Users["user8"] = &types.User{ Id: "user8", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - account.Users["user9"] = &User{ + account.Users["user9"] = &types.User{ Id: "user9", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -834,11 +902,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { } func TestDefaultAccountManager_GetUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -863,13 +935,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { } func TestDefaultAccountManager_ListUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) - err := store.SaveAccount(context.Background(), account) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") + account.Users["normal_user1"] = types.NewRegularUser("normal_user1") + account.Users["normal_user2"] = types.NewRegularUser("normal_user2") + + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -901,43 +977,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool expectedDashboardPermissions string }{ { name: "Regular user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, expectedDashboardPermissions: "limited", }, { name: "Admin user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, expectedDashboardPermissions: "blocked", }, { name: "Admin user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, expectedDashboardPermissions: "full", }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, expectedDashboardPermissions: "full", }, @@ -945,13 +1021,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -976,13 +1057,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { } func TestDefaultAccountManager_ExternalCache(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ + externalUser := &types.User{ Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + Role: types.UserRoleUser, + Issued: types.UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", @@ -990,7 +1075,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1020,7 +1105,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) - var user *UserInfo + var user *types.UserInfo for _, info := range infos { if info.ID == externalUser.Id { user = info @@ -1032,24 +1117,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestUser_IsAdmin(t *testing.T) { - user := NewAdminUser(mockUserID) + user := types.NewAdminUser(mockUserID) assert.True(t, user.HasAdminPower()) - user = NewRegularUser(mockUserID) + user = types.NewRegularUser(mockUserID) assert.False(t, user.HasAdminPower()) } func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1068,17 +1157,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { } func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1112,25 +1204,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { tt := []struct { name string initiatorID string - update *User + update *types.User expectedErr bool }{ { name: "Should_Fail_To_Update_Admin_Role", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, Blocked: false, }, }, { name: "Should_Fail_When_Admin_Blocks_Themselves", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1138,9 +1230,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Non_Existing_User", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: userID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1148,9 +1240,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1158,9 +1250,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Update_User", expectedErr: false, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1168,9 +1260,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Transfer_Owner_Role_To_User", expectedErr: false, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1178,9 +1270,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User", expectedErr: true, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: serviceUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1188,9 +1280,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1198,9 +1290,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_User", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1208,9 +1300,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User", expectedErr: true, initiatorID: serviceUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1218,9 +1310,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1228,9 +1320,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Block_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: true, }, }, @@ -1246,9 +1338,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} + account.Users[regularUserID] = types.NewRegularUser(regularUserID) + account.Users[adminUserID] = types.NewAdminUser(adminUserID) + account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) @@ -1272,22 +1364,22 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) require.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1307,11 +1399,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1330,11 +1422,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) @@ -1364,11 +1456,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) // create a user and add new peer with the user - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1390,11 +1482,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) diff --git a/management/server/users/manager.go b/management/server/users/manager.go new file mode 100644 index 000000000..76291a678 --- /dev/null +++ b/management/server/users/manager.go @@ -0,0 +1,26 @@ +package users + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetUser(ctx context.Context, userID string) (*types.User, error) +} + +type managerImpl struct { + store store.Store +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { + return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +} diff --git a/management/server/util/util.go b/management/server/util/util.go new file mode 100644 index 000000000..ff738781f --- /dev/null +++ b/management/server/util/util.go @@ -0,0 +1,16 @@ +package util + +// Difference returns the elements in `a` that aren't in `b`. +func Difference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +}