diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 664e8be18..4571ce753 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -1,4 +1,4 @@ -name: Test Code Darwin +name: "Darwin" on: push: @@ -12,9 +12,7 @@ concurrency: jobs: test: - strategy: - matrix: - store: ['sqlite'] + name: "Client / Unit" runs-on: macos-latest steps: - name: Install Go diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 0f510cb3a..e1c688b1b 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -1,5 +1,4 @@ - -name: Test Code FreeBSD +name: "FreeBSD" on: push: @@ -13,6 +12,7 @@ concurrency: jobs: test: + name: "Client / Unit" runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index efe1a2654..3be8bcff3 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -316,7 +316,7 @@ jobs: NETBIRD_STORE_ENGINE=${{ matrix.store }} \ go test -tags=devcert \ -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ - -timeout 10m ./management/... + -timeout 20m ./management/... benchmark: name: "Management / Benchmark" @@ -508,7 +508,7 @@ jobs: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 10m ./management/... + -timeout 20m ./management/... test_client_on_docker: name: "Client (Docker) / Unit" diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 782e4c30a..d9ff0a84b 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -1,4 +1,4 @@ -name: Test Code Windows +name: "Windows" on: push: @@ -14,6 +14,7 @@ concurrency: jobs: test: + name: "Client / Unit" runs-on: windows-latest steps: - name: Checkout code diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 6705a34ec..ca075d30f 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -1,4 +1,4 @@ -name: golangci-lint +name: Lint on: [pull_request] permissions: @@ -27,7 +27,14 @@ jobs: fail-fast: false matrix: os: [macos-latest, windows-latest, ubuntu-latest] - name: lint + include: + - os: macos-latest + display_name: Darwin + - os: windows-latest + display_name: Windows + - os: ubuntu-latest + display_name: Linux + name: ${{ matrix.display_name }} runs-on: ${{ matrix.os }} timeout-minutes: 15 steps: diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index dcf461a34..569956a54 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -1,4 +1,4 @@ -name: Mobile build validation +name: Mobile on: push: @@ -12,6 +12,7 @@ concurrency: jobs: android_build: + name: "Android / Build" runs-on: ubuntu-latest steps: - name: Checkout repository @@ -47,6 +48,7 @@ jobs: CGO_ENABLED: 0 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 ios_build: + name: "iOS / Build" runs-on: macos-latest steps: - name: Checkout repository diff --git a/.gitignore b/.gitignore index d0b4f82dd..abb728b19 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ infrastructure_files/setup.env infrastructure_files/setup-*.env .vscode .DS_Store +vendor/ diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index 983aa0e78..1dd649d1b 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -50,6 +50,8 @@ nfpms: - netbird-ui formats: - deb + scripts: + postinstall: "release_files/ui-post-install.sh" contents: - src: client/ui/netbird.desktop dst: /usr/share/applications/netbird.desktop @@ -67,6 +69,8 @@ nfpms: - netbird-ui formats: - rpm + scripts: + postinstall: "release_files/ui-post-install.sh" contents: - src: client/ui/netbird.desktop dst: /usr/share/applications/netbird.desktop diff --git a/README.md b/README.md index 0537710e9..5b136eff6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@
+
+

@@ -31,6 +33,10 @@
+
+ + Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts +


diff --git a/client/Dockerfile b/client/Dockerfile index 2f5ff14ae..35c1d04c2 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.21.0 +FROM alpine:3.21.3 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] diff --git a/client/cmd/debug.go b/client/cmd/debug.go index c7ab87b47..c02f60aed 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/server" + nbstatus "github.com/netbirdio/netbird/client/status" ) const errCloseConnection = "Failed to close connection: %v" @@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error { client := proto.NewDaemonServiceClient(conn) resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, - Status: getStatusOutput(cmd), + Status: getStatusOutput(cmd, anonymizeFlag), SystemInfo: debugSystemInfoFlag, }) if err != nil { @@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { time.Sleep(3 * time.Second) headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) - statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd)) + statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag)) if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr @@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Creating debug bundle...") headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) - statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) + statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, @@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error { return nil } -func getStatusOutput(cmd *cobra.Command) string { +func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string statusResp, err := getStatus(cmd.Context()) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { - statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp)) + statusOutputString = nbstatus.ParseToFullDetailSummary( + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), + ) } return statusOutputString } diff --git a/client/cmd/status.go b/client/cmd/status.go index 1deef487b..0ddba8b2f 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -2,108 +2,20 @@ package cmd import ( "context" - "encoding/json" "fmt" "net" "net/netip" - "os" - "runtime" - "sort" "strings" - "time" "github.com/spf13/cobra" "google.golang.org/grpc/status" - "gopkg.in/yaml.v3" - "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/version" ) -type peerStateDetailOutput struct { - FQDN string `json:"fqdn" yaml:"fqdn"` - IP string `json:"netbirdIp" yaml:"netbirdIp"` - PubKey string `json:"publicKey" yaml:"publicKey"` - Status string `json:"status" yaml:"status"` - LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` - ConnType string `json:"connectionType" yaml:"connectionType"` - IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` - IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` - RelayAddress string `json:"relayAddress" yaml:"relayAddress"` - LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` - TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` - TransferSent int64 `json:"transferSent" yaml:"transferSent"` - Latency time.Duration `json:"latency" yaml:"latency"` - RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` - Networks []string `json:"networks" yaml:"networks"` -} - -type peersStateOutput struct { - Total int `json:"total" yaml:"total"` - Connected int `json:"connected" yaml:"connected"` - Details []peerStateDetailOutput `json:"details" yaml:"details"` -} - -type signalStateOutput struct { - URL string `json:"url" yaml:"url"` - Connected bool `json:"connected" yaml:"connected"` - Error string `json:"error" yaml:"error"` -} - -type managementStateOutput struct { - URL string `json:"url" yaml:"url"` - Connected bool `json:"connected" yaml:"connected"` - Error string `json:"error" yaml:"error"` -} - -type relayStateOutputDetail struct { - URI string `json:"uri" yaml:"uri"` - Available bool `json:"available" yaml:"available"` - Error string `json:"error" yaml:"error"` -} - -type relayStateOutput struct { - Total int `json:"total" yaml:"total"` - Available int `json:"available" yaml:"available"` - Details []relayStateOutputDetail `json:"details" yaml:"details"` -} - -type iceCandidateType struct { - Local string `json:"local" yaml:"local"` - Remote string `json:"remote" yaml:"remote"` -} - -type nsServerGroupStateOutput struct { - Servers []string `json:"servers" yaml:"servers"` - Domains []string `json:"domains" yaml:"domains"` - Enabled bool `json:"enabled" yaml:"enabled"` - Error string `json:"error" yaml:"error"` -} - -type statusOutputOverview struct { - Peers peersStateOutput `json:"peers" yaml:"peers"` - CliVersion string `json:"cliVersion" yaml:"cliVersion"` - DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` - ManagementState managementStateOutput `json:"management" yaml:"management"` - SignalState signalStateOutput `json:"signal" yaml:"signal"` - Relays relayStateOutput `json:"relays" yaml:"relays"` - IP string `json:"netbirdIp" yaml:"netbirdIp"` - PubKey string `json:"publicKey" yaml:"publicKey"` - KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` - FQDN string `json:"fqdn" yaml:"fqdn"` - RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` - RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` - - Networks []string `json:"networks" yaml:"networks"` - NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"` - NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` - Events []systemEventOutput `json:"events" yaml:"events"` -} - var ( detailFlag bool ipv4Flag bool @@ -174,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } - outputInformationHolder := convertToStatusOutputOverview(resp) - + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var statusOutputString string switch { case detailFlag: - statusOutputString = parseToFullDetailSummary(outputInformationHolder) + statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder) case jsonFlag: - statusOutputString, err = parseToJSON(outputInformationHolder) + statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder) case yamlFlag: - statusOutputString, err = parseToYAML(outputInformationHolder) + statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) default: - statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false) + statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false) } if err != nil { @@ -215,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } func parseFilters() error { - switch strings.ToLower(statusFilter) { case "", "disconnected", "connected": if strings.ToLower(statusFilter) != "" { @@ -252,176 +162,6 @@ func enableDetailFlagWhenFilterFlag() { } } -func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview { - pbFullStatus := resp.GetFullStatus() - - managementState := pbFullStatus.GetManagementState() - managementOverview := managementStateOutput{ - URL: managementState.GetURL(), - Connected: managementState.GetConnected(), - Error: managementState.Error, - } - - signalState := pbFullStatus.GetSignalState() - signalOverview := signalStateOutput{ - URL: signalState.GetURL(), - Connected: signalState.GetConnected(), - Error: signalState.Error, - } - - relayOverview := mapRelays(pbFullStatus.GetRelays()) - peersOverview := mapPeers(resp.GetFullStatus().GetPeers()) - - overview := statusOutputOverview{ - Peers: peersOverview, - CliVersion: version.NetbirdVersion(), - DaemonVersion: resp.GetDaemonVersion(), - ManagementState: managementOverview, - SignalState: signalOverview, - Relays: relayOverview, - IP: pbFullStatus.GetLocalPeerState().GetIP(), - PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(), - KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(), - FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), - RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), - RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), - - Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), - NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()), - NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), - Events: mapEvents(pbFullStatus.GetEvents()), - } - - if anonymizeFlag { - anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) - anonymizeOverview(anonymizer, &overview) - } - - return overview -} - -func mapRelays(relays []*proto.RelayState) relayStateOutput { - var relayStateDetail []relayStateOutputDetail - - var relaysAvailable int - for _, relay := range relays { - available := relay.GetAvailable() - relayStateDetail = append(relayStateDetail, - relayStateOutputDetail{ - URI: relay.URI, - Available: available, - Error: relay.GetError(), - }, - ) - - if available { - relaysAvailable++ - } - } - - return relayStateOutput{ - Total: len(relays), - Available: relaysAvailable, - Details: relayStateDetail, - } -} - -func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput { - mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers)) - for _, pbNsGroupServer := range servers { - mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{ - Servers: pbNsGroupServer.GetServers(), - Domains: pbNsGroupServer.GetDomains(), - Enabled: pbNsGroupServer.GetEnabled(), - Error: pbNsGroupServer.GetError(), - }) - } - return mappedNSGroups -} - -func mapPeers(peers []*proto.PeerState) peersStateOutput { - var peersStateDetail []peerStateDetailOutput - peersConnected := 0 - for _, pbPeerState := range peers { - localICE := "" - remoteICE := "" - localICEEndpoint := "" - remoteICEEndpoint := "" - relayServerAddress := "" - connType := "" - lastHandshake := time.Time{} - transferReceived := int64(0) - transferSent := int64(0) - - isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if skipDetailByFilters(pbPeerState, isPeerConnected) { - continue - } - if isPeerConnected { - peersConnected++ - - localICE = pbPeerState.GetLocalIceCandidateType() - remoteICE = pbPeerState.GetRemoteIceCandidateType() - localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() - remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() - connType = "P2P" - if pbPeerState.Relayed { - connType = "Relayed" - } - relayServerAddress = pbPeerState.GetRelayAddress() - lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() - transferReceived = pbPeerState.GetBytesRx() - transferSent = pbPeerState.GetBytesTx() - } - - timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local() - peerState := peerStateDetailOutput{ - IP: pbPeerState.GetIP(), - PubKey: pbPeerState.GetPubKey(), - Status: pbPeerState.GetConnStatus(), - LastStatusUpdate: timeLocal, - ConnType: connType, - IceCandidateType: iceCandidateType{ - Local: localICE, - Remote: remoteICE, - }, - IceCandidateEndpoint: iceCandidateType{ - Local: localICEEndpoint, - Remote: remoteICEEndpoint, - }, - RelayAddress: relayServerAddress, - FQDN: pbPeerState.GetFqdn(), - LastWireguardHandshake: lastHandshake, - TransferReceived: transferReceived, - TransferSent: transferSent, - Latency: pbPeerState.GetLatency().AsDuration(), - RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), - Networks: pbPeerState.GetNetworks(), - } - - peersStateDetail = append(peersStateDetail, peerState) - } - - sortPeersByIP(peersStateDetail) - - peersOverview := peersStateOutput{ - Total: len(peersStateDetail), - Connected: peersConnected, - Details: peersStateDetail, - } - return peersOverview -} - -func sortPeersByIP(peersStateDetail []peerStateDetailOutput) { - if len(peersStateDetail) > 0 { - sort.SliceStable(peersStateDetail, func(i, j int) bool { - iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP) - jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP) - return iAddr.Compare(jAddr) == -1 - }) - } -} - func parseInterfaceIP(interfaceIP string) string { ip, _, err := net.ParseCIDR(interfaceIP) if err != nil { @@ -429,451 +169,3 @@ func parseInterfaceIP(interfaceIP string) string { } return fmt.Sprintf("%s\n", ip) } - -func parseToJSON(overview statusOutputOverview) (string, error) { - jsonBytes, err := json.Marshal(overview) - if err != nil { - return "", fmt.Errorf("json marshal failed") - } - return string(jsonBytes), err -} - -func parseToYAML(overview statusOutputOverview) (string, error) { - yamlBytes, err := yaml.Marshal(overview) - if err != nil { - return "", fmt.Errorf("yaml marshal failed") - } - return string(yamlBytes), nil -} - -func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string { - var managementConnString string - if overview.ManagementState.Connected { - managementConnString = "Connected" - if showURL { - managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL) - } - } else { - managementConnString = "Disconnected" - if overview.ManagementState.Error != "" { - managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error) - } - } - - var signalConnString string - if overview.SignalState.Connected { - signalConnString = "Connected" - if showURL { - signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL) - } - } else { - signalConnString = "Disconnected" - if overview.SignalState.Error != "" { - signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error) - } - } - - interfaceTypeString := "Userspace" - interfaceIP := overview.IP - if overview.KernelInterface { - interfaceTypeString = "Kernel" - } else if overview.IP == "" { - interfaceTypeString = "N/A" - interfaceIP = "N/A" - } - - var relaysString string - if showRelays { - for _, relay := range overview.Relays.Details { - available := "Available" - reason := "" - if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) - } - relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) - } - } else { - relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) - } - - networks := "-" - if len(overview.Networks) > 0 { - sort.Strings(overview.Networks) - networks = strings.Join(overview.Networks, ", ") - } - - var dnsServersString string - if showNameServers { - for _, nsServerGroup := range overview.NSServerGroups { - enabled := "Available" - if !nsServerGroup.Enabled { - enabled = "Unavailable" - } - errorString := "" - if nsServerGroup.Error != "" { - errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error) - errorString = strings.TrimSpace(errorString) - } - - domainsString := strings.Join(nsServerGroup.Domains, ", ") - if domainsString == "" { - domainsString = "." // Show "." for the default zone - } - dnsServersString += fmt.Sprintf( - "\n [%s] for [%s] is %s%s", - strings.Join(nsServerGroup.Servers, ", "), - domainsString, - enabled, - errorString, - ) - } - } else { - dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups)) - } - - rosenpassEnabledStatus := "false" - if overview.RosenpassEnabled { - rosenpassEnabledStatus = "true" - if overview.RosenpassPermissive { - rosenpassEnabledStatus = "true (permissive)" //nolint:gosec - } - } - - peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) - - goos := runtime.GOOS - goarch := runtime.GOARCH - goarm := "" - if goarch == "arm" { - goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM")) - } - - summary := fmt.Sprintf( - "OS: %s\n"+ - "Daemon version: %s\n"+ - "CLI version: %s\n"+ - "Management: %s\n"+ - "Signal: %s\n"+ - "Relays: %s\n"+ - "Nameservers: %s\n"+ - "FQDN: %s\n"+ - "NetBird IP: %s\n"+ - "Interface type: %s\n"+ - "Quantum resistance: %s\n"+ - "Networks: %s\n"+ - "Forwarding rules: %d\n"+ - "Peers count: %s\n", - fmt.Sprintf("%s/%s%s", goos, goarch, goarm), - overview.DaemonVersion, - version.NetbirdVersion(), - managementConnString, - signalConnString, - relaysString, - dnsServersString, - overview.FQDN, - interfaceIP, - interfaceTypeString, - rosenpassEnabledStatus, - networks, - overview.NumberOfForwardingRules, - peersCountString, - ) - return summary -} - -func parseToFullDetailSummary(overview statusOutputOverview) string { - parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) - parsedEventsString := parseEvents(overview.Events) - summary := parseGeneralSummary(overview, true, true, true) - - return fmt.Sprintf( - "Peers detail:"+ - "%s\n"+ - "Events:"+ - "%s\n"+ - "%s", - parsedPeersString, - parsedEventsString, - summary, - ) -} - -func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string { - var ( - peersString = "" - ) - - for _, peerState := range peers.Details { - - localICE := "-" - if peerState.IceCandidateType.Local != "" { - localICE = peerState.IceCandidateType.Local - } - - remoteICE := "-" - if peerState.IceCandidateType.Remote != "" { - remoteICE = peerState.IceCandidateType.Remote - } - - localICEEndpoint := "-" - if peerState.IceCandidateEndpoint.Local != "" { - localICEEndpoint = peerState.IceCandidateEndpoint.Local - } - - remoteICEEndpoint := "-" - if peerState.IceCandidateEndpoint.Remote != "" { - remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote - } - - rosenpassEnabledStatus := "false" - if rosenpassEnabled { - if peerState.RosenpassEnabled { - rosenpassEnabledStatus = "true" - } else { - if rosenpassPermissive { - rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)" - } else { - rosenpassEnabledStatus = "false (connection won't work without a permissive mode)" - } - } - } else { - if peerState.RosenpassEnabled { - rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)" - } - } - - networks := "-" - if len(peerState.Networks) > 0 { - sort.Strings(peerState.Networks) - networks = strings.Join(peerState.Networks, ", ") - } - - peerString := fmt.Sprintf( - "\n %s:\n"+ - " NetBird IP: %s\n"+ - " Public key: %s\n"+ - " Status: %s\n"+ - " -- detail --\n"+ - " Connection type: %s\n"+ - " ICE candidate (Local/Remote): %s/%s\n"+ - " ICE candidate endpoints (Local/Remote): %s/%s\n"+ - " Relay server address: %s\n"+ - " Last connection update: %s\n"+ - " Last WireGuard handshake: %s\n"+ - " Transfer status (received/sent) %s/%s\n"+ - " Quantum resistance: %s\n"+ - " Networks: %s\n"+ - " Latency: %s\n", - peerState.FQDN, - peerState.IP, - peerState.PubKey, - peerState.Status, - peerState.ConnType, - localICE, - remoteICE, - localICEEndpoint, - remoteICEEndpoint, - peerState.RelayAddress, - timeAgo(peerState.LastStatusUpdate), - timeAgo(peerState.LastWireguardHandshake), - toIEC(peerState.TransferReceived), - toIEC(peerState.TransferSent), - rosenpassEnabledStatus, - networks, - peerState.Latency.String(), - ) - - peersString += peerString - } - return peersString -} - -func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { - statusEval := false - ipEval := false - nameEval := true - - if statusFilter != "" { - lowerStatusFilter := strings.ToLower(statusFilter) - if lowerStatusFilter == "disconnected" && isConnected { - statusEval = true - } else if lowerStatusFilter == "connected" && !isConnected { - statusEval = true - } - } - - if len(ipsFilter) > 0 { - _, ok := ipsFilterMap[peerState.IP] - if !ok { - ipEval = true - } - } - - if len(prefixNamesFilter) > 0 { - for prefixNameFilter := range prefixNamesFilterMap { - if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { - nameEval = false - break - } - } - } else { - nameEval = false - } - - return statusEval || ipEval || nameEval -} - -func toIEC(b int64) string { - const unit = 1024 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %ciB", - float64(b)/float64(div), "KMGTPE"[exp]) -} - -func countEnabled(dnsServers []nsServerGroupStateOutput) int { - count := 0 - for _, server := range dnsServers { - if server.Enabled { - count++ - } - } - return count -} - -// timeAgo returns a string representing the duration since the provided time in a human-readable format. -func timeAgo(t time.Time) string { - if t.IsZero() || t.Equal(time.Unix(0, 0)) { - return "-" - } - duration := time.Since(t) - switch { - case duration < time.Second: - return "Now" - case duration < time.Minute: - seconds := int(duration.Seconds()) - if seconds == 1 { - return "1 second ago" - } - return fmt.Sprintf("%d seconds ago", seconds) - case duration < time.Hour: - minutes := int(duration.Minutes()) - seconds := int(duration.Seconds()) % 60 - if minutes == 1 { - if seconds == 1 { - return "1 minute, 1 second ago" - } else if seconds > 0 { - return fmt.Sprintf("1 minute, %d seconds ago", seconds) - } - return "1 minute ago" - } - if seconds > 0 { - return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds) - } - return fmt.Sprintf("%d minutes ago", minutes) - case duration < 24*time.Hour: - hours := int(duration.Hours()) - minutes := int(duration.Minutes()) % 60 - if hours == 1 { - if minutes == 1 { - return "1 hour, 1 minute ago" - } else if minutes > 0 { - return fmt.Sprintf("1 hour, %d minutes ago", minutes) - } - return "1 hour ago" - } - if minutes > 0 { - return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes) - } - return fmt.Sprintf("%d hours ago", hours) - } - - days := int(duration.Hours()) / 24 - hours := int(duration.Hours()) % 24 - if days == 1 { - if hours == 1 { - return "1 day, 1 hour ago" - } else if hours > 0 { - return fmt.Sprintf("1 day, %d hours ago", hours) - } - return "1 day ago" - } - if hours > 0 { - return fmt.Sprintf("%d days, %d hours ago", days, hours) - } - return fmt.Sprintf("%d days ago", days) -} - -func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { - peer.FQDN = a.AnonymizeDomain(peer.FQDN) - if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil { - peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port) - } - if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { - peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) - } - - peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) - - for i, route := range peer.Networks { - peer.Networks[i] = a.AnonymizeIPString(route) - } - - for i, route := range peer.Networks { - peer.Networks[i] = a.AnonymizeRoute(route) - } -} - -func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) { - for i, peer := range overview.Peers.Details { - peer := peer - anonymizePeerDetail(a, &peer) - overview.Peers.Details[i] = peer - } - - overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL) - overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error) - overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL) - overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error) - - overview.IP = a.AnonymizeIPString(overview.IP) - for i, detail := range overview.Relays.Details { - detail.URI = a.AnonymizeURI(detail.URI) - detail.Error = a.AnonymizeString(detail.Error) - overview.Relays.Details[i] = detail - } - - for i, nsGroup := range overview.NSServerGroups { - for j, domain := range nsGroup.Domains { - overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain) - } - for j, ns := range nsGroup.Servers { - host, port, err := net.SplitHostPort(ns) - if err == nil { - overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port) - } - } - } - - for i, route := range overview.Networks { - overview.Networks[i] = a.AnonymizeRoute(route) - } - - overview.FQDN = a.AnonymizeDomain(overview.FQDN) - - for i, event := range overview.Events { - overview.Events[i].Message = a.AnonymizeString(event.Message) - overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage) - - for k, v := range event.Metadata { - event.Metadata[k] = a.AnonymizeString(v) - } - } -} diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index 0b0ae4c51..03608eab0 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -1,583 +1,11 @@ package cmd import ( - "bytes" - "encoding/json" - "fmt" - "runtime" "testing" - "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/version" ) -func init() { - loc, err := time.LoadLocation("UTC") - if err != nil { - panic(err) - } - - time.Local = loc -} - -var resp = &proto.StatusResponse{ - Status: "Connected", - FullStatus: &proto.FullStatus{ - Peers: []*proto.PeerState{ - { - IP: "192.168.178.101", - PubKey: "Pubkey1", - Fqdn: "peer-1.awesome-domain.com", - ConnStatus: "Connected", - ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), - Relayed: false, - LocalIceCandidateType: "", - RemoteIceCandidateType: "", - LocalIceCandidateEndpoint: "", - RemoteIceCandidateEndpoint: "", - LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), - BytesRx: 200, - BytesTx: 100, - Networks: []string{ - "10.1.0.0/24", - }, - Latency: durationpb.New(time.Duration(10000000)), - }, - { - IP: "192.168.178.102", - PubKey: "Pubkey2", - Fqdn: "peer-2.awesome-domain.com", - ConnStatus: "Connected", - ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), - Relayed: true, - LocalIceCandidateType: "relay", - RemoteIceCandidateType: "prflx", - LocalIceCandidateEndpoint: "10.0.0.1:10001", - RemoteIceCandidateEndpoint: "10.0.10.1:10002", - LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)), - BytesRx: 2000, - BytesTx: 1000, - Latency: durationpb.New(time.Duration(10000000)), - }, - }, - ManagementState: &proto.ManagementState{ - URL: "my-awesome-management.com:443", - Connected: true, - Error: "", - }, - SignalState: &proto.SignalState{ - URL: "my-awesome-signal.com:443", - Connected: true, - Error: "", - }, - Relays: []*proto.RelayState{ - { - URI: "stun:my-awesome-stun.com:3478", - Available: true, - Error: "", - }, - { - URI: "turns:my-awesome-turn.com:443?transport=tcp", - Available: false, - Error: "context: deadline exceeded", - }, - }, - LocalPeerState: &proto.LocalPeerState{ - IP: "192.168.178.100/16", - PubKey: "Some-Pub-Key", - KernelInterface: true, - Fqdn: "some-localhost.awesome-domain.com", - Networks: []string{ - "10.10.0.0/24", - }, - }, - DnsServers: []*proto.NSGroupState{ - { - Servers: []string{ - "8.8.8.8:53", - }, - Domains: nil, - Enabled: true, - Error: "", - }, - { - Servers: []string{ - "1.1.1.1:53", - "2.2.2.2:53", - }, - Domains: []string{ - "example.com", - "example.net", - }, - Enabled: false, - Error: "timeout", - }, - }, - }, - DaemonVersion: "0.14.1", -} - -var overview = statusOutputOverview{ - Peers: peersStateOutput{ - Total: 2, - Connected: 2, - Details: []peerStateDetailOutput{ - { - IP: "192.168.178.101", - PubKey: "Pubkey1", - FQDN: "peer-1.awesome-domain.com", - Status: "Connected", - LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), - ConnType: "P2P", - IceCandidateType: iceCandidateType{ - Local: "", - Remote: "", - }, - IceCandidateEndpoint: iceCandidateType{ - Local: "", - Remote: "", - }, - LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC), - TransferReceived: 200, - TransferSent: 100, - Networks: []string{ - "10.1.0.0/24", - }, - Latency: time.Duration(10000000), - }, - { - IP: "192.168.178.102", - PubKey: "Pubkey2", - FQDN: "peer-2.awesome-domain.com", - Status: "Connected", - LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), - ConnType: "Relayed", - IceCandidateType: iceCandidateType{ - Local: "relay", - Remote: "prflx", - }, - IceCandidateEndpoint: iceCandidateType{ - Local: "10.0.0.1:10001", - Remote: "10.0.10.1:10002", - }, - LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC), - TransferReceived: 2000, - TransferSent: 1000, - Latency: time.Duration(10000000), - }, - }, - }, - Events: []systemEventOutput{}, - CliVersion: version.NetbirdVersion(), - DaemonVersion: "0.14.1", - ManagementState: managementStateOutput{ - URL: "my-awesome-management.com:443", - Connected: true, - Error: "", - }, - SignalState: signalStateOutput{ - URL: "my-awesome-signal.com:443", - Connected: true, - Error: "", - }, - Relays: relayStateOutput{ - Total: 2, - Available: 1, - Details: []relayStateOutputDetail{ - { - URI: "stun:my-awesome-stun.com:3478", - Available: true, - Error: "", - }, - { - URI: "turns:my-awesome-turn.com:443?transport=tcp", - Available: false, - Error: "context: deadline exceeded", - }, - }, - }, - IP: "192.168.178.100/16", - PubKey: "Some-Pub-Key", - KernelInterface: true, - FQDN: "some-localhost.awesome-domain.com", - NSServerGroups: []nsServerGroupStateOutput{ - { - Servers: []string{ - "8.8.8.8:53", - }, - Domains: nil, - Enabled: true, - Error: "", - }, - { - Servers: []string{ - "1.1.1.1:53", - "2.2.2.2:53", - }, - Domains: []string{ - "example.com", - "example.net", - }, - Enabled: false, - Error: "timeout", - }, - }, - Networks: []string{ - "10.10.0.0/24", - }, -} - -func TestConversionFromFullStatusToOutputOverview(t *testing.T) { - convertedResult := convertToStatusOutputOverview(resp) - - assert.Equal(t, overview, convertedResult) -} - -func TestSortingOfPeers(t *testing.T) { - peers := []peerStateDetailOutput{ - { - IP: "192.168.178.104", - }, - { - IP: "192.168.178.102", - }, - { - IP: "192.168.178.101", - }, - { - IP: "192.168.178.105", - }, - { - IP: "192.168.178.103", - }, - } - - sortPeersByIP(peers) - - assert.Equal(t, peers[3].IP, "192.168.178.104") -} - -func TestParsingToJSON(t *testing.T) { - jsonString, _ := parseToJSON(overview) - - //@formatter:off - expectedJSONString := ` - { - "peers": { - "total": 2, - "connected": 2, - "details": [ - { - "fqdn": "peer-1.awesome-domain.com", - "netbirdIp": "192.168.178.101", - "publicKey": "Pubkey1", - "status": "Connected", - "lastStatusUpdate": "2001-01-01T01:01:01Z", - "connectionType": "P2P", - "iceCandidateType": { - "local": "", - "remote": "" - }, - "iceCandidateEndpoint": { - "local": "", - "remote": "" - }, - "relayAddress": "", - "lastWireguardHandshake": "2001-01-01T01:01:02Z", - "transferReceived": 200, - "transferSent": 100, - "latency": 10000000, - "quantumResistance": false, - "networks": [ - "10.1.0.0/24" - ] - }, - { - "fqdn": "peer-2.awesome-domain.com", - "netbirdIp": "192.168.178.102", - "publicKey": "Pubkey2", - "status": "Connected", - "lastStatusUpdate": "2002-02-02T02:02:02Z", - "connectionType": "Relayed", - "iceCandidateType": { - "local": "relay", - "remote": "prflx" - }, - "iceCandidateEndpoint": { - "local": "10.0.0.1:10001", - "remote": "10.0.10.1:10002" - }, - "relayAddress": "", - "lastWireguardHandshake": "2002-02-02T02:02:03Z", - "transferReceived": 2000, - "transferSent": 1000, - "latency": 10000000, - "quantumResistance": false, - "networks": null - } - ] - }, - "cliVersion": "development", - "daemonVersion": "0.14.1", - "management": { - "url": "my-awesome-management.com:443", - "connected": true, - "error": "" - }, - "signal": { - "url": "my-awesome-signal.com:443", - "connected": true, - "error": "" - }, - "relays": { - "total": 2, - "available": 1, - "details": [ - { - "uri": "stun:my-awesome-stun.com:3478", - "available": true, - "error": "" - }, - { - "uri": "turns:my-awesome-turn.com:443?transport=tcp", - "available": false, - "error": "context: deadline exceeded" - } - ] - }, - "netbirdIp": "192.168.178.100/16", - "publicKey": "Some-Pub-Key", - "usesKernelInterface": true, - "fqdn": "some-localhost.awesome-domain.com", - "quantumResistance": false, - "quantumResistancePermissive": false, - "networks": [ - "10.10.0.0/24" - ], - "forwardingRules": 0, - "dnsServers": [ - { - "servers": [ - "8.8.8.8:53" - ], - "domains": null, - "enabled": true, - "error": "" - }, - { - "servers": [ - "1.1.1.1:53", - "2.2.2.2:53" - ], - "domains": [ - "example.com", - "example.net" - ], - "enabled": false, - "error": "timeout" - } - ], - "events": [] - }` - // @formatter:on - - var expectedJSON bytes.Buffer - require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString))) - - assert.Equal(t, expectedJSON.String(), jsonString) -} - -func TestParsingToYAML(t *testing.T) { - yaml, _ := parseToYAML(overview) - - expectedYAML := - `peers: - total: 2 - connected: 2 - details: - - fqdn: peer-1.awesome-domain.com - netbirdIp: 192.168.178.101 - publicKey: Pubkey1 - status: Connected - lastStatusUpdate: 2001-01-01T01:01:01Z - connectionType: P2P - iceCandidateType: - local: "" - remote: "" - iceCandidateEndpoint: - local: "" - remote: "" - relayAddress: "" - lastWireguardHandshake: 2001-01-01T01:01:02Z - transferReceived: 200 - transferSent: 100 - latency: 10ms - quantumResistance: false - networks: - - 10.1.0.0/24 - - fqdn: peer-2.awesome-domain.com - netbirdIp: 192.168.178.102 - publicKey: Pubkey2 - status: Connected - lastStatusUpdate: 2002-02-02T02:02:02Z - connectionType: Relayed - iceCandidateType: - local: relay - remote: prflx - iceCandidateEndpoint: - local: 10.0.0.1:10001 - remote: 10.0.10.1:10002 - relayAddress: "" - lastWireguardHandshake: 2002-02-02T02:02:03Z - transferReceived: 2000 - transferSent: 1000 - latency: 10ms - quantumResistance: false - networks: [] -cliVersion: development -daemonVersion: 0.14.1 -management: - url: my-awesome-management.com:443 - connected: true - error: "" -signal: - url: my-awesome-signal.com:443 - connected: true - error: "" -relays: - total: 2 - available: 1 - details: - - uri: stun:my-awesome-stun.com:3478 - available: true - error: "" - - uri: turns:my-awesome-turn.com:443?transport=tcp - available: false - error: 'context: deadline exceeded' -netbirdIp: 192.168.178.100/16 -publicKey: Some-Pub-Key -usesKernelInterface: true -fqdn: some-localhost.awesome-domain.com -quantumResistance: false -quantumResistancePermissive: false -networks: - - 10.10.0.0/24 -forwardingRules: 0 -dnsServers: - - servers: - - 8.8.8.8:53 - domains: [] - enabled: true - error: "" - - servers: - - 1.1.1.1:53 - - 2.2.2.2:53 - domains: - - example.com - - example.net - enabled: false - error: timeout -events: [] -` - - assert.Equal(t, expectedYAML, yaml) -} - -func TestParsingToDetail(t *testing.T) { - // Calculate time ago based on the fixture dates - lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate) - lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake) - lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate) - lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake) - - detail := parseToFullDetailSummary(overview) - - expectedDetail := fmt.Sprintf( - `Peers detail: - peer-1.awesome-domain.com: - NetBird IP: 192.168.178.101 - Public key: Pubkey1 - Status: Connected - -- detail -- - Connection type: P2P - ICE candidate (Local/Remote): -/- - ICE candidate endpoints (Local/Remote): -/- - Relay server address: - Last connection update: %s - Last WireGuard handshake: %s - Transfer status (received/sent) 200 B/100 B - Quantum resistance: false - Networks: 10.1.0.0/24 - Latency: 10ms - - peer-2.awesome-domain.com: - NetBird IP: 192.168.178.102 - Public key: Pubkey2 - Status: Connected - -- detail -- - Connection type: Relayed - ICE candidate (Local/Remote): relay/prflx - ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 - Relay server address: - Last connection update: %s - Last WireGuard handshake: %s - Transfer status (received/sent) 2.0 KiB/1000 B - Quantum resistance: false - Networks: - - Latency: 10ms - -Events: No events recorded -OS: %s/%s -Daemon version: 0.14.1 -CLI version: %s -Management: Connected to my-awesome-management.com:443 -Signal: Connected to my-awesome-signal.com:443 -Relays: - [stun:my-awesome-stun.com:3478] is Available - [turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded -Nameservers: - [8.8.8.8:53] for [.] is Available - [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout -FQDN: some-localhost.awesome-domain.com -NetBird IP: 192.168.178.100/16 -Interface type: Kernel -Quantum resistance: false -Networks: 10.10.0.0/24 -Forwarding rules: 0 -Peers count: 2/2 Connected -`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) - - assert.Equal(t, expectedDetail, detail) -} - -func TestParsingToShortVersion(t *testing.T) { - shortVersion := parseGeneralSummary(overview, false, false, false) - - expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` -Daemon version: 0.14.1 -CLI version: development -Management: Connected -Signal: Connected -Relays: 1/2 Available -Nameservers: 1/2 Available -FQDN: some-localhost.awesome-domain.com -NetBird IP: 192.168.178.100/16 -Interface type: Kernel -Quantum resistance: false -Networks: 10.10.0.0/24 -Forwarding rules: 0 -Peers count: 2/2 Connected -` - - assert.Equal(t, expectedString, shortVersion) -} - func TestParsingOfIP(t *testing.T) { InterfaceIP := "192.168.178.123/16" @@ -585,31 +13,3 @@ func TestParsingOfIP(t *testing.T) { assert.Equal(t, "192.168.178.123\n", parsedIP) } - -func TestTimeAgo(t *testing.T) { - now := time.Now() - - cases := []struct { - name string - input time.Time - expected string - }{ - {"Now", now, "Now"}, - {"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"}, - {"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"}, - {"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"}, - {"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"}, - {"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"}, - {"One day ago", now.Add(-24 * time.Hour), "1 day ago"}, - {"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"}, - {"Zero time", time.Time{}, "-"}, - {"Unix zero time", time.Unix(0, 0), "-"}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - result := timeAgo(tc.input) - assert.Equal(t, tc.expected, result, "Failed %s", tc.name) - }) - } -} diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index ee67d2501..4c06a7da0 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -96,7 +96,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 97e4662fd..c37740587 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -245,33 +245,29 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu defer bufPool.Put(bufp) buffer := *bufp - if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) - } - if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set write deadline: %w", err) - } - for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return ctx.Err() - default: - n, err := src.Read(buffer) - if err != nil { - if isTimeout(err) { - continue - } - return fmt.Errorf("read from %s: %w", direction, err) - } - - _, err = dst.Write(buffer[:n]) - if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) - } - - c.updateLastSeen() } + + if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + + n, err := src.Read(buffer) + if err != nil { + if isTimeout(err) { + continue + } + return fmt.Errorf("read from %s: %w", direction, err) + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + return fmt.Errorf("write to %s: %w", direction, err) + } + + c.updateLastSeen() } } diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 00a91f0ec..4c827de95 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "slices" "strings" "sync" @@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } - var localAddrsForUnspecified []net.Addr - if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { - params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) - } else if ok && addr.IP.IsUnspecified() { - // For unspecified addresses, the correct behavior is to return errListenUnspecified, but - // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") - var networks []ice.NetworkType - switch { - - case addr.IP.To16() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} - - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} - - default: - params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) - } - if len(networks) > 0 { - if params.Net == nil { - var err error - if params.Net, err = stdnet.NewNet(); err != nil { - params.Logger.Errorf("failed to get create network: %v", err) - } - } - - ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true) - if err == nil { - for _, ip := range ips { - localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) - } - } else { - params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err) - } - } - } - - return &UDPMuxDefault{ + mux := &UDPMuxDefault{ addressMap: map[string][]*udpMuxedConn{}, params: params, connsIPv4: make(map[string]*udpMuxedConn), @@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return newBufferHolder(receiveMTU + maxAddrSize) }, }, - localAddrsForUnspecified: localAddrsForUnspecified, } + + mux.updateLocalAddresses() + return mux +} + +func (m *UDPMuxDefault) updateLocalAddresses() { + var localAddrsForUnspecified []net.Addr + if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { + m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) + } else if ok && addr.IP.IsUnspecified() { + // For unspecified addresses, the correct behavior is to return errListenUnspecified, but + // it will break the applications that are already using unspecified UDP connection + // with UDPMuxDefault, so print a warn log and create a local address list for mux. + m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + var networks []ice.NetworkType + switch { + + case addr.IP.To16() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + + default: + m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr()) + } + if len(networks) > 0 { + if m.params.Net == nil { + var err error + if m.params.Net, err = stdnet.NewNet(); err != nil { + m.params.Logger.Errorf("failed to get create network: %v", err) + } + } + + ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true) + if err == nil { + for _, ip := range ips { + localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) + } + } else { + m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err) + } + } + } + + m.mu.Lock() + m.localAddrsForUnspecified = localAddrsForUnspecified + m.mu.Unlock() } // LocalAddr returns the listening address of this UDPMuxDefault @@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr { // GetListenAddresses returns the list of addresses that this mux is listening on func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { + m.updateLocalAddresses() + + m.mu.Lock() + defer m.mu.Unlock() if len(m.localAddrsForUnspecified) > 0 { - return m.localAddrsForUnspecified + return slices.Clone(m.localAddrsForUnspecified) } return []net.Addr{m.LocalAddr()} @@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // creates the connection if an existing one can't be found func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { // don't check addr for mux using unspecified address - if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { + m.mu.Lock() + lenLocalAddrs := len(m.localAddrsForUnspecified) + m.mu.Unlock() + if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() { return nil, fmt.Errorf("invalid address %s", addr.String()) } diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 7c1c41669..6f09a63c9 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error return nil } -func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - // parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli PublicKey: peerKeyParsed, ReplaceAllowedIPs: false, // don't replace allowed ips, wg will handle duplicated peer IP - AllowedIPs: []net.IPNet{*ipNet}, + AllowedIPs: allowedIps, PersistentKeepaliveInterval: &keepAlive, Endpoint: endpoint, PresharedKey: preSharedKey, diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 391269dd0..a3de58c24 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - // parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv PublicKey: peerKeyParsed, ReplaceAllowedIPs: false, // don't replace allowed ips, wg will handle duplicated peer IP - AllowedIPs: []net.IPNet{*ipNet}, + AllowedIPs: allowedIps, PersistentKeepaliveInterval: &keepAlive, PresharedKey: preSharedKey, Endpoint: endpoint, diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 0196b0085..6971b6946 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -11,7 +11,7 @@ import ( type WGConfigurer interface { ConfigureInterface(privateKey string, port int) error - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/iface.go b/client/iface/iface.go index 8056dd9a6..40bd51fbb 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -3,6 +3,7 @@ package iface import ( "fmt" "net" + "net/netip" "sync" "time" @@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error { // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // Endpoint is optional -func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { w.mu.Lock() defer w.mu.Unlock() + netIPNets := prefixesToIPNets(allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net { return w.tun.GetNet() } + +func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet { + ipNets := make([]net.IPNet, len(prefixes)) + for i, prefix := range prefixes { + ipNets[i] = net.IPNet{ + IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP + Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask + } + } + return ipNets +} diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go deleted file mode 100644 index f92a8cfc8..000000000 --- a/client/iface/iface_moc.go +++ /dev/null @@ -1,123 +0,0 @@ -package iface - -import ( - "net" - "time" - - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -type MockWGIface struct { - CreateFunc func() error - CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error - IsUserspaceBindFunc func() bool - NameFunc func() string - AddressFunc func() device.WGAddress - ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) - UpdateAddrFunc func(newAddr string) error - UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - RemovePeerFunc func(peerKey string) error - AddAllowedIPFunc func(peerKey string, allowedIP string) error - RemoveAllowedIPFunc func(peerKey string, allowedIP string) error - CloseFunc func() error - SetFilterFunc func(filter device.PacketFilter) error - GetFilterFunc func() device.PacketFilter - GetDeviceFunc func() *device.FilteredDevice - GetWGDeviceFunc func() *wgdevice.Device - GetStatsFunc func(peerKey string) (configurer.WGStats, error) - GetInterfaceGUIDStringFunc func() (string, error) - GetProxyFunc func() wgproxy.Proxy - GetNetFunc func() *netstack.Net -} - -func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { - return m.GetInterfaceGUIDStringFunc() -} - -func (m *MockWGIface) Create() error { - return m.CreateFunc() -} - -func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { - return m.CreateOnAndroidFunc(routeRange, ip, domains) -} - -func (m *MockWGIface) IsUserspaceBind() bool { - return m.IsUserspaceBindFunc() -} - -func (m *MockWGIface) Name() string { - return m.NameFunc() -} - -func (m *MockWGIface) Address() device.WGAddress { - return m.AddressFunc() -} - -func (m *MockWGIface) ToInterface() *net.Interface { - return m.ToInterfaceFunc() -} - -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { - return m.UpFunc() -} - -func (m *MockWGIface) UpdateAddr(newAddr string) error { - return m.UpdateAddrFunc(newAddr) -} - -func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) -} - -func (m *MockWGIface) RemovePeer(peerKey string) error { - return m.RemovePeerFunc(peerKey) -} - -func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { - return m.AddAllowedIPFunc(peerKey, allowedIP) -} - -func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { - return m.RemoveAllowedIPFunc(peerKey, allowedIP) -} - -func (m *MockWGIface) Close() error { - return m.CloseFunc() -} - -func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { - return m.SetFilterFunc(filter) -} - -func (m *MockWGIface) GetFilter() device.PacketFilter { - return m.GetFilterFunc() -} - -func (m *MockWGIface) GetDevice() *device.FilteredDevice { - return m.GetDeviceFunc() -} - -func (m *MockWGIface) GetWGDevice() *wgdevice.Device { - return m.GetWGDeviceFunc() -} - -func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { - return m.GetStatsFunc(peerKey) -} - -func (m *MockWGIface) GetProxy() wgproxy.Proxy { - return m.GetProxyFunc() -} - -func (m *MockWGIface) GetNet() *netstack.Net { - return m.GetNetFunc() -} diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 85db9cacb..e890b30f3 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } keepAlive := 15 * time.Second - allowedIP := "10.99.99.10/32" + allowedIP := netip.MustParsePrefix("10.99.99.10/32") endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900") if err != nil { t.Fatal(err) } - err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil) + err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) { var foundAllowedIP bool for _, aip := range peer.AllowedIPs { - if aip.String() == allowedIP { + if aip.String() == allowedIP.String() { foundAllowedIP = true break } @@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } keepAlive := 15 * time.Second - allowedIP := "10.99.99.14/32" - - err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil) + allowedIP := netip.MustParsePrefix("10.99.99.14/32") + err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil) if err != nil { t.Fatal(err) } @@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) { func Test_ConnectPeers(t *testing.T) { peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) - peer1wgIP := "10.99.99.17/30" + peer1wgIP := netip.MustParsePrefix("10.99.99.17/30") peer1Key, _ := wgtypes.GeneratePrivateKey() peer1wgPort := 33100 peer2ifaceName := "utun500" - peer2wgIP := "10.99.99.18/30" + peer2wgIP := netip.MustParsePrefix("10.99.99.18/30") peer2Key, _ := wgtypes.GeneratePrivateKey() peer2wgPort := 33200 @@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer1 := WGIFaceOpts{ IFaceName: peer1ifaceName, - Address: peer1wgIP, + Address: peer1wgIP.String(), WGPort: peer1wgPort, WGPrivKey: peer1Key.String(), MTU: DefaultMTU, @@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer2 := WGIFaceOpts{ IFaceName: peer2ifaceName, - Address: peer2wgIP, + Address: peer2wgIP.String(), WGPort: peer2wgPort, WGPrivKey: peer2Key.String(), MTU: DefaultMTU, @@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) { } }() - err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil) + err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil) if err != nil { t.Fatal(err) } - err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil) + err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil) if err != nil { t.Fatal(err) } diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go deleted file mode 100644 index cac096b54..000000000 --- a/client/iface/iwginterface_windows.go +++ /dev/null @@ -1,39 +0,0 @@ -package iface - -import ( - "net" - "time" - - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -type IWGIface interface { - Create() error - CreateOnAndroid(routeRange []string, ip string, domains []string) error - IsUserspaceBind() bool - Name() string - Address() device.WGAddress - ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(newAddr string) error - GetProxy() wgproxy.Proxy - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - RemovePeer(peerKey string) error - AddAllowedIP(peerKey string, allowedIP string) error - RemoveAllowedIP(peerKey string, allowedIP string) error - Close() error - SetFilter(filter device.PacketFilter) error - GetFilter() device.PacketFilter - GetDevice() *device.FilteredDevice - GetWGDevice() *wgdevice.Device - GetStats(peerKey string) (configurer.WGStats, error) - GetInterfaceGUIDString() (string, error) - GetNet() *netstack.Net -} diff --git a/client/internal/config.go b/client/internal/config.go index b269a3854..b2f96cbdc 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -99,7 +99,7 @@ type Config struct { BlockLANAccess bool - DisableNotifications bool + DisableNotifications *bool DNSLabels domain.List @@ -479,13 +479,20 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.DisableNotifications != nil && *input.DisableNotifications != config.DisableNotifications { + if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if *input.DisableNotifications { log.Infof("disabling notifications") } else { log.Infof("enabling notifications") } - config.DisableNotifications = *input.DisableNotifications + config.DisableNotifications = input.DisableNotifications + updated = true + } + + if config.DisableNotifications == nil { + disabled := true + config.DisableNotifications = &disabled + log.Infof("setting notifications to disabled by default") updated = true } diff --git a/client/internal/connect.go b/client/internal/connect.go index 26ae3b687..bf513ed39 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" + cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/management/client" @@ -104,6 +105,16 @@ func (c *ConnectClient) RunOniOS( func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { defer func() { if r := recover(); r != nil { + rec := c.statusRecorder + if rec != nil { + rec.PublishEvent( + cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM, + "panic occurred", + "The Netbird service panicked. Please restart the service and submit a bug report with the client logs.", + nil, + ) + } + log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) } }() diff --git a/client/internal/dns.go b/client/internal/dns.go new file mode 100644 index 000000000..8a73f50f2 --- /dev/null +++ b/client/internal/dns.go @@ -0,0 +1,111 @@ +package internal + +import ( + "fmt" + "net" + "slices" + "strings" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { + ip := net.ParseIP(aRecord.RData) + if ip == nil || ip.To4() == nil { + return nbdns.SimpleRecord{}, false + } + + if !ipNet.Contains(ip) { + return nbdns.SimpleRecord{}, false + } + + ipOctets := strings.Split(ip.String(), ".") + slices.Reverse(ipOctets) + rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa") + + return nbdns.SimpleRecord{ + Name: rdnsName, + Type: int(dns.TypePTR), + Class: aRecord.Class, + TTL: aRecord.TTL, + RData: dns.Fqdn(aRecord.Name), + }, true +} + +// generateReverseZoneName creates the reverse DNS zone name for a given network +func generateReverseZoneName(ipNet *net.IPNet) (string, error) { + networkIP := ipNet.IP.Mask(ipNet.Mask) + maskOnes, _ := ipNet.Mask.Size() + + // round up to nearest byte + octetsToUse := (maskOnes + 7) / 8 + + octets := strings.Split(networkIP.String(), ".") + if octetsToUse > len(octets) { + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) + } + + reverseOctets := make([]string, octetsToUse) + for i := 0; i < octetsToUse; i++ { + reverseOctets[octetsToUse-1-i] = octets[i] + } + + return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil +} + +// zoneExists checks if a zone with the given name already exists in the configuration +func zoneExists(config *nbdns.Config, zoneName string) bool { + for _, zone := range config.CustomZones { + if zone.Domain == zoneName { + log.Debugf("reverse DNS zone %s already exists", zoneName) + return true + } + } + return false +} + +// collectPTRRecords gathers all PTR records for the given network from A records +func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + + for _, zone := range config.CustomZones { + for _, record := range zone.Records { + if record.Type != int(dns.TypeA) { + continue + } + + if ptrRecord, ok := createPTRRecord(record, ipNet); ok { + records = append(records, ptrRecord) + } + } + } + + return records +} + +// addReverseZone adds a reverse DNS zone to the configuration for the given network +func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { + zoneName, err := generateReverseZoneName(ipNet) + if err != nil { + log.Warn(err) + return + } + + if zoneExists(config, zoneName) { + log.Debugf("reverse DNS zone %s already exists", zoneName) + return + } + + records := collectPTRRecords(config, ipNet) + + reverseZone := nbdns.CustomZone{ + Domain: zoneName, + Records: records, + } + + config.CustomZones = append(config.CustomZones, reverseZone) + log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records)) +} diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 02ae26e10..1f4ddb67c 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) } } - return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") + return ErrRouteAllWithoutNameserverGroup } if !backupFileExist { @@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error { return f.restore() } +func (f *fileConfigurator) string() string { + return "file" +} + func (f *fileConfigurator) backup() error { stats, err := os.Stat(defaultResolvConfPath) if err != nil { diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index fbe8c4dbb..25e9ff7e5 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -9,10 +9,18 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) +var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") + +const ( + ipv4ReverseZone = ".in-addr.arpa" + ipv6ReverseZone = ".ip6.arpa" +) + type hostManager interface { applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool + string() string } type SystemDNSSettings struct { @@ -39,6 +47,7 @@ type mockHostConfigurator struct { restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error + stringFunc func() string } func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { @@ -62,6 +71,13 @@ func (m *mockHostConfigurator) supportCustomPort() bool { return false } +func (m *mockHostConfigurator) string() string { + if m.stringFunc != nil { + return m.stringFunc() + } + return "mock" +} + func newNoopHostMocker() hostManager { return &mockHostConfigurator{ applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, @@ -94,9 +110,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD } for _, customZone := range dnsConfig.CustomZones { + matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ Domain: strings.TrimSuffix(customZone.Domain, "."), - MatchOnly: false, + MatchOnly: matchOnly, }) } @@ -116,3 +133,7 @@ func (n noopHostConfigurator) restoreHostDNS() error { func (n noopHostConfigurator) supportCustomPort() bool { return true } + +func (n noopHostConfigurator) string() string { + return "noop" +} diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 5653710d7..dfa3e5712 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -22,3 +22,7 @@ func (a androidHostManager) restoreHostDNS() error { func (a androidHostManager) supportCustomPort() bool { return false } + +func (a androidHostManager) string() string { + return "none" +} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 2f92dd367..f727f68b5 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -114,6 +114,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) string() string { + return "scutil" +} + func (s *systemConfigurator) restoreHostDNS() error { keys := s.getRemovableKeysWithDefaults() for _, key := range keys { diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index 4a0acf572..1c0ac63e9 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -38,3 +38,7 @@ func (a iosHostManager) restoreHostDNS() error { func (a iosHostManager) supportCustomPort() bool { return false } + +func (a iosHostManager) string() string { + return "none" +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 58b0a14de..dceb24420 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -184,6 +184,10 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return nil } +func (r *registryConfigurator) string() string { + return "registry" +} + func (r *registryConfigurator) updateSearchDomains(domains []string) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil { return fmt.Errorf("update search domains: %w", err) diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 63bbead77..10b4e6a6e 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -179,6 +179,10 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error { return nil } +func (n *networkManagerDbusConfigurator) string() string { + return "network-manager" +} + func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) { obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) if err != nil { diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 6b5fdaf86..54c4c75bf 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -91,7 +91,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman if err != nil { log.Errorf("restore host dns: %s", err) } - return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured") + return ErrRouteAllWithoutNameserverGroup } searchDomainList := searchDomains(config) @@ -139,6 +139,10 @@ func (r *resolvconf) restoreHostDNS() error { return nil } +func (r *resolvconf) string() string { + return fmt.Sprintf("resolvconf (%s)", r.implType) +} + func (r *resolvconf) applyConfig(content bytes.Buffer) error { var cmd *exec.Cmd diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4d68370d..bc87012f2 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -2,6 +2,7 @@ package dns import ( "context" + "errors" "fmt" "net/netip" "runtime" @@ -15,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" + cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" ) @@ -395,12 +397,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) + return fmt.Errorf("local handler updater: %w", err) } upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) + return fmt.Errorf("upstream handler updater: %w", err) } muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic @@ -420,6 +422,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { log.Error(err) + s.handleErrNoGroupaAll(err) } go func() { @@ -438,16 +441,33 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { return nil } +func (s *DefaultServer) handleErrNoGroupaAll(err error) { + if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { + return + } + + if s.statusRecorder == nil { + return + } + + s.statusRecorder.PublishEvent( + cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS, + "The host dns manager does not support match domains", + "The host dns manager does not support match domains without a catch-all nameserver group.", + map[string]string{"manager": s.hostManager.string()}, + ) +} + func (s *DefaultServer) buildLocalHandlerUpdate( customZones []nbdns.CustomZone, ) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) { - var muxUpdates []handlerWrapper localRecords := make(map[string][]nbdns.SimpleRecord) for _, customZone := range customZones { if len(customZone.Records) == 0 { - return nil, nil, fmt.Errorf("received an empty list of records") + log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain) + continue } muxUpdates = append(muxUpdates, handlerWrapper{ @@ -460,7 +480,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate( for _, record := range customZone.Records { var class uint16 = dns.ClassINET if record.Class != nbdns.DefaultClass { - return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + log.Warnf("received an invalid class type: %s", record.Class) + continue } key := buildRecordKey(record.Name, class, uint16(record.Type)) @@ -670,6 +691,7 @@ func (s *DefaultServer) upstreamCallbacks( } if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { + s.handleErrNoGroupaAll(err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } @@ -708,6 +730,7 @@ func (s *DefaultServer) upstreamCallbacks( if s.hostManager != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { + s.handleErrNoGroupaAll(err) l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index e9ddd5f59..94b87124b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid Custom Zone Records list Should Fail", + name: "Invalid Custom Zone Records list Should Skip", initLocalMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap), initSerial: 0, @@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - shouldFail: true, + expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ + domain: ".", + handler: dummyHandler, + priority: PriorityDefault, + }}, }, { name: "Empty Config Should Succeed and Clean Maps", @@ -352,7 +356,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) if err != nil { t.Fatal(err) } @@ -409,7 +413,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet([]string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -461,7 +465,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -562,7 +566,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) if err != nil { t.Fatalf("%v", err) } @@ -635,7 +639,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { {false, "domain2", false}, }, }, - statusRecorder: &peer.Status{}, + statusRecorder: peer.NewRecorder("mgm"), } var domainsUpdate string @@ -696,7 +700,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { var dnsList []string dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -720,7 +724,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -812,7 +816,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -883,7 +887,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet([]string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index a031be582..a87cc73e5 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -154,6 +154,10 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana return nil } +func (s *systemdDbusConfigurator) string() string { + return "dbus" +} + func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error { err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput) if err != nil { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index d269107e3..a22689cf9 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -183,6 +183,19 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) { } u.disable(err) + + if u.statusRecorder == nil { + return + } + + u.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "All upstream servers failed (fail count exceeded)", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, + // TODO add domain meta + ) } // probeAvailability tests all upstream servers simultaneously and @@ -232,10 +245,14 @@ func (u *upstreamResolverBase) probeAvailability() { if !success { u.disable(errors.ErrorOrNil()) + if u.statusRecorder == nil { + return + } + u.statusRecorder.PublishEvent( proto.SystemEvent_WARNING, proto.SystemEvent_DNS, - "All upstream servers failed", + "All upstream servers failed (probe failed)", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, ) diff --git a/client/internal/engine.go b/client/internal/engine.go index 865df5fbb..3d7802675 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -44,6 +44,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" @@ -155,7 +156,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface iface.IWGIface + wgInterface WGIface udpMux *bind.UniversalUDPMuxDefault @@ -535,15 +536,18 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { - if allowedIPs != strings.Join(p.AllowedIps, ",") { - modified = append(modified, p) - continue - } - err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) - if err != nil { - log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) - } + allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey) + if !ok { + continue + } + if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) { + modified = append(modified, p) + continue + } + + err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) + if err != nil { + log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) } } @@ -682,6 +686,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return err } + e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil) + return nil } @@ -967,7 +973,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -1036,7 +1042,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { return dnsRoutes } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), @@ -1076,6 +1082,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { } dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup) } + + if len(dnsUpdate.CustomZones) > 0 { + addReverseZone(&dnsUpdate, network) + } + return dnsUpdate } @@ -1109,34 +1120,45 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // addNewPeer add peer if connection doesn't exist func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() - peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerStore.PeerConn(peerKey); !ok { - conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) - if err != nil { - return fmt.Errorf("create peer connection: %w", err) - } - - if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { - conn.Close() - return fmt.Errorf("peer already exists: %s", peerKey) - } - - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } - - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) - if err != nil { - log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) - } - - conn.Open() + peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps())) + if _, ok := e.peerStore.PeerConn(peerKey); ok { + return nil } + + for _, ipString := range peerConfig.GetAllowedIps() { + allowedNetIP, err := netip.ParsePrefix(ipString) + if err != nil { + log.Errorf("failed to parse allowedIPS: %v", err) + return err + } + peerIPs = append(peerIPs, allowedNetIP) + } + + conn, err := e.createPeerConn(peerKey, peerIPs) + if err != nil { + return fmt.Errorf("create peer connection: %w", err) + } + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } + + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) + if err != nil { + log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) + } + + conn.Open() return nil } -func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) { +func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) { log.Debugf("creating peer connection %s", pubKey) wgConfig := peer.WgConfig{ @@ -1382,7 +1404,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { return nil, nil, err } routes := toRoutes(netMap.GetRoutes()) - dnsCfg := toDNSConfig(netMap.GetDNSConfig()) + dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) return routes, &dnsCfg, nil } @@ -1889,3 +1911,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { return prefixes, nberrors.FormatErrorOrNil(merr) } + +// compareNetIPLists compares a list of netip.Prefix with a list of strings. +// return true if both lists are equal, false otherwise. +func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool { + if len(list1) != len(list2) { + return false + } + + freq := make(map[string]int, len(list1)) + for _, p := range list1 { + freq[p.String()]++ + } + + for _, s := range list2 { + p, err := netip.ParsePrefix(s) + if err != nil { + return false // invalid prefix in list2. + } + key := p.String() + if freq[key] == 0 { + return false + } + freq[key]-- + } + + // all counts should be zero if lists are equal. + for _, count := range freq { + if count != 0 { + return false + } + } + return true +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 1ac9a0430..9de1da28d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -22,11 +22,16 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -65,6 +70,114 @@ var ( } ) +type MockWGIface struct { + CreateFunc func() error + CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error + IsUserspaceBindFunc func() bool + NameFunc func() string + AddressFunc func() device.WGAddress + ToInterfaceFunc func() *net.Interface + UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpdateAddrFunc func(newAddr string) error + UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeerFunc func(peerKey string) error + AddAllowedIPFunc func(peerKey string, allowedIP string) error + RemoveAllowedIPFunc func(peerKey string, allowedIP string) error + CloseFunc func() error + SetFilterFunc func(filter device.PacketFilter) error + GetFilterFunc func() device.PacketFilter + GetDeviceFunc func() *device.FilteredDevice + GetWGDeviceFunc func() *wgdevice.Device + GetStatsFunc func(peerKey string) (configurer.WGStats, error) + GetInterfaceGUIDStringFunc func() (string, error) + GetProxyFunc func() wgproxy.Proxy + GetNetFunc func() *netstack.Net +} + +func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { + return m.GetInterfaceGUIDStringFunc() +} + +func (m *MockWGIface) Create() error { + return m.CreateFunc() +} + +func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { + return m.CreateOnAndroidFunc(routeRange, ip, domains) +} + +func (m *MockWGIface) IsUserspaceBind() bool { + return m.IsUserspaceBindFunc() +} + +func (m *MockWGIface) Name() string { + return m.NameFunc() +} + +func (m *MockWGIface) Address() device.WGAddress { + return m.AddressFunc() +} + +func (m *MockWGIface) ToInterface() *net.Interface { + return m.ToInterfaceFunc() +} + +func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { + return m.UpFunc() +} + +func (m *MockWGIface) UpdateAddr(newAddr string) error { + return m.UpdateAddrFunc(newAddr) +} + +func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) +} + +func (m *MockWGIface) RemovePeer(peerKey string) error { + return m.RemovePeerFunc(peerKey) +} + +func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { + return m.AddAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + return m.RemoveAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) Close() error { + return m.CloseFunc() +} + +func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { + return m.SetFilterFunc(filter) +} + +func (m *MockWGIface) GetFilter() device.PacketFilter { + return m.GetFilterFunc() +} + +func (m *MockWGIface) GetDevice() *device.FilteredDevice { + return m.GetDeviceFunc() +} + +func (m *MockWGIface) GetWGDevice() *wgdevice.Device { + return m.GetWGDeviceFunc() +} + +func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return m.GetStatsFunc(peerKey) +} + +func (m *MockWGIface) GetProxy() wgproxy.Proxy { + return m.GetProxyFunc() +} + +func (m *MockWGIface) GetNet() *netstack.Net { + return m.GetNetFunc() +} + func TestMain(m *testing.M) { _ = util.InitLog("debug", "console") code := m.Run() @@ -246,11 +359,20 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { peer.NewRecorder("https://mgm"), nil) - wgIface := &iface.MockWGIface{ + wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, RemovePeerFunc: func(peerKey string) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, } engine.wgInterface = wgIface engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ @@ -414,7 +536,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } expectedAllowedIPs := strings.Join(p.AllowedIps, ",") - if conn.WgConfig().AllowedIps != expectedAllowedIPs { + if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) { t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(), expectedAllowedIPs, conn.WgConfig().AllowedIps) } @@ -693,6 +815,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, }, }, + { + Domain: "0.66.100.in-addr.arpa.", + }, }, NameServerGroups: []*mgmtProto.NameServerGroup{ { @@ -722,6 +847,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, }, }, + { + Domain: "0.66.100.in-addr.arpa.", + }, }, expectedNSGroupsLen: 1, expectedNSGroups: []*nbdns.NameServerGroup{ @@ -1111,6 +1239,91 @@ func Test_CheckFilesEqual(t *testing.T) { } } +func TestCompareNetIPLists(t *testing.T) { + tests := []struct { + name string + list1 []netip.Prefix + list2 []string + expected bool + }{ + { + name: "both empty", + list1: []netip.Prefix{}, + list2: []string{}, + expected: true, + }, + { + name: "single match ipv4", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + list2: []string{"192.168.0.0/24"}, + expected: true, + }, + { + name: "multiple match ipv4, different order", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")}, + list2: []string{"10.0.0.0/8", "192.168.1.0/24"}, + expected: true, + }, + { + name: "ipv4 mismatch due to extra element in list2", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"192.168.1.0/24", "10.0.0.0/8"}, + expected: false, + }, + { + name: "ipv4 mismatch due to duplicate count", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"192.168.1.0/24"}, + expected: false, + }, + { + name: "invalid prefix in list2", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"invalid-prefix"}, + expected: false, + }, + { + name: "ipv4 mismatch because different prefixes", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"10.0.0.0/8"}, + expected: false, + }, + { + name: "single match ipv6", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"2001:db8::/32"}, + expected: true, + }, + { + name: "multiple match ipv6, different order", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")}, + list2: []string{"fe80::/10", "2001:db8::/32"}, + expected: true, + }, + { + name: "mixed ipv4 and ipv6 match", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"2001:db8::/32", "192.168.1.0/24"}, + expected: true, + }, + { + name: "ipv6 mismatch with invalid prefix", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"invalid-ipv6"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := compareNetIPLists(tt.list1, tt.list2) + if result != tt.expected { + t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected) + } + }) + } +} + func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -1131,7 +1344,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } info := system.GetInfo(ctx) - resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil) + resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) if err != nil { return nil, err } @@ -1227,7 +1440,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { return nil, "", err } diff --git a/client/internal/iface.go b/client/internal/iface.go new file mode 100644 index 000000000..bd0069c19 --- /dev/null +++ b/client/internal/iface.go @@ -0,0 +1,8 @@ +//go:build !windows +// +build !windows + +package internal + +type WGIface interface { + wgIfaceBase +} diff --git a/client/iface/iwginterface.go b/client/internal/iface_common.go similarity index 84% rename from client/iface/iwginterface.go rename to client/internal/iface_common.go index 2b919ac9e..65b425015 100644 --- a/client/iface/iwginterface.go +++ b/client/internal/iface_common.go @@ -1,9 +1,8 @@ -//go:build !windows - -package iface +package internal import ( "net" + "net/netip" "time" wgdevice "golang.zx2c4.com/wireguard/device" @@ -16,7 +15,7 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy" ) -type IWGIface interface { +type wgIfaceBase interface { Create() error CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool @@ -26,7 +25,7 @@ type IWGIface interface { Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error diff --git a/client/internal/iface_windows.go b/client/internal/iface_windows.go new file mode 100644 index 000000000..113217815 --- /dev/null +++ b/client/internal/iface_windows.go @@ -0,0 +1,6 @@ +package internal + +type WGIface interface { + wgIfaceBase + GetInterfaceGUIDString() (string, error) +} diff --git a/client/internal/login.go b/client/internal/login.go index 092f2309c..395a17199 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -140,7 +140,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. config.DisableDNS, config.DisableFirewall, ) - loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey) + loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { log.Errorf("failed registering peer %v,%s", err, validSetupKey.String()) return nil, err diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8bbea6a2b..9b4d1a554 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -5,9 +5,9 @@ import ( "fmt" "math/rand" "net" + "net/netip" "os" "runtime" - "strings" "sync" "time" @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -56,8 +55,8 @@ const ( type WgConfig struct { WgListenPort int RemoteKey string - WgInterface iface.IWGIface - AllowedIps string + WgInterface WGIface + AllowedIps []netip.Prefix PreSharedKey *wgtypes.Key } @@ -92,11 +91,10 @@ type Conn struct { statusRecorder *Status signaler *Signaler relayManager *relayClient.Manager - allowedIP net.IP handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) - onDisconnected func(remotePeer string, wgIP string) + onDisconnected func(remotePeer string) statusRelay *AtomicConnStatus statusICE *AtomicConnStatus @@ -121,10 +119,8 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { - allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) - if err != nil { - log.Errorf("failed to parse allowedIPS: %v", err) - return nil, err + if len(config.WgConfig.AllowedIps) == 0 { + return nil, fmt.Errorf("allowed IPs is empty") } ctx, ctxCancel := context.WithCancel(engineCtx) @@ -138,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu statusRecorder: statusRecorder, signaler: signaler, relayManager: relayManager, - allowedIP: allowedIP, statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), semaphore: semaphore, @@ -148,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) + workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) if err != nil { return nil, err } + conn.workerICE = workerICE conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) @@ -180,7 +176,7 @@ func (conn *Conn) Open() { peerState := State{ PubKey: conn.config.Key, - IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], + IP: conn.config.WgConfig.AllowedIps[0].Addr().String(), ConnStatusUpdate: time.Now(), ConnStatus: StatusDisconnected, Mux: new(sync.RWMutex), @@ -246,7 +242,7 @@ func (conn *Conn) Close() { conn.freeUpConnID() if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { - conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) + conn.onDisconnected(conn.config.WgConfig.RemoteKey) } conn.setStatusToDisconnected() @@ -277,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR } // SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected -func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) { +func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) { conn.onDisconnected = handler } @@ -602,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr) } } @@ -699,7 +695,7 @@ func (conn *Conn) freeUpConnID() { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.log.Debugf("setup proxied WireGuard connection") udpAddr := &net.UDPAddr{ - IP: conn.allowedIP, + IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(), Port: conn.config.WgConfig.WgListenPort, } @@ -753,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { } // AllowedIP returns the allowed IP of the remote peer -func (conn *Conn) AllowedIP() net.IP { - return conn.allowedIP +func (conn *Conn) AllowedIP() netip.Addr { + return conn.config.WgConfig.AllowedIps[0].Addr() } func isController(config ConnConfig) bool { diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go new file mode 100644 index 000000000..c7b6de9ea --- /dev/null +++ b/client/internal/peer/iface.go @@ -0,0 +1,19 @@ +package peer + +import ( + "net" + "net/netip" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +type WGIface interface { + UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + GetStats(peerKey string) (configurer.WGStats, error) + GetProxy() wgproxy.Proxy +} diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go index 6b3385ff5..15d34d3d0 100644 --- a/client/internal/peerstore/store.go +++ b/client/internal/peerstore/store.go @@ -1,7 +1,7 @@ package peerstore import ( - "net" + "net/netip" "sync" "golang.org/x/exp/maps" @@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { return p, true } -func (s *Store) AllowedIPs(pubKey string) (string, bool) { - s.peerConnsMu.RLock() - defer s.peerConnsMu.RUnlock() - - p, ok := s.peerConns[pubKey] - if !ok { - return "", false - } - return p.WgConfig().AllowedIps, true -} - -func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { +func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) { s.peerConnsMu.RLock() defer s.peerConnsMu.RUnlock() @@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { if !ok { return nil, false } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return netip.Addr{}, false + } return p.AllowedIP(), true } diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index bf019453b..d2d7408fd 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) { return cfg, nil } -func (m *Manager) OnDisconnected(peerKey string, wgIP string) { +func (m *Manager) OnDisconnected(peerKey string) { m.lock.Lock() defer m.lock.Unlock() diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 5984e69cb..6680f727a 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -11,12 +11,12 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/proto" @@ -62,7 +62,7 @@ type clientNetwork struct { ctx context.Context cancel context.CancelFunc statusRecorder *peer.Status - wgInterface iface.IWGIface + wgInterface iface.WGIface routes map[route.ID]*route.Route routeUpdate chan routesUpdate peerStateUpdate chan struct{} @@ -75,7 +75,7 @@ type clientNetwork struct { func newClientNetworkWatcher( ctx context.Context, dnsRouteInterval time.Duration, - wgInterface iface.IWGIface, + wgInterface iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, @@ -306,11 +306,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error return nil } + var isNew bool if c.currentChosen == nil { // If they were not previously assigned to another peer, add routes to the system first if err := c.handler.AddRoute(c.ctx); err != nil { return fmt.Errorf("add route: %w", err) } + isNew = true } else { // Otherwise, remove the allowed IPs from the previous peer first if err := c.removeRouteFromWireGuardPeer(); err != nil { @@ -324,6 +326,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } + if isNew { + c.connectEvent() + } + err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) if err != nil { return fmt.Errorf("add peer state route: %w", err) @@ -331,6 +337,35 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error return nil } +func (c *clientNetwork) connectEvent() { + var defaultRoute bool + for _, r := range c.routes { + if r.Network.Bits() == 0 { + defaultRoute = true + break + } + } + + if !defaultRoute { + return + } + + meta := map[string]string{ + "network": c.handler.String(), + } + if c.currentChosen != nil { + meta["id"] = string(c.currentChosen.NetID) + meta["peer"] = c.currentChosen.Peer + } + c.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_NETWORK, + "Default route added", + "Exit node connected.", + meta, + ) +} + func (c *clientNetwork) disconnectEvent(rsn reason) { var defaultRoute bool for _, r := range c.routes { @@ -349,29 +384,27 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { var userMessage string meta := make(map[string]string) + if c.currentChosen != nil { + meta["id"] = string(c.currentChosen.NetID) + meta["peer"] = c.currentChosen.Peer + } + meta["network"] = c.handler.String() switch rsn { case reasonShutdown: severity = proto.SystemEvent_INFO message = "Default route removed" userMessage = "Exit node disconnected." - meta["network"] = c.handler.String() case reasonRouteUpdate: severity = proto.SystemEvent_INFO message = "Default route updated due to configuration change" - meta["network"] = c.handler.String() case reasonPeerUpdate: severity = proto.SystemEvent_WARNING message = "Default route disconnected due to peer unreachability" userMessage = "Exit node connection lost. Your internet access might be affected." - if c.currentChosen != nil { - meta["peer"] = c.currentChosen.Peer - meta["network"] = c.handler.String() - } default: severity = proto.SystemEvent_ERROR - message = "Default route disconnected for unknown reason" + message = "Default route disconnected for unknown reasons" userMessage = "Exit node disconnected for unknown reasons." - meta["network"] = c.handler.String() } c.statusRecorder.PublishEvent( @@ -468,7 +501,7 @@ func handlerFromRoute( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, - wgInterface iface.IWGIface, + wgInterface iface.WGIface, dnsServer nbdns.Server, peerStore *peerstore.Store, useNewDNSRoute bool, diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 10cb03f1d..f36285cc4 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -3,7 +3,6 @@ package dnsinterceptor import ( "context" "fmt" - "net" "net/netip" "strings" "sync" @@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { Timeout: 5 * time.Second, Net: "udp", } - upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := client.ExchangeContext(context.Background(), r, upstream) var answer []dns.RR if reply != nil { answer = reply.Answer } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) if err != nil { log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) @@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, } } -func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) { peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) if !exists { - return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey) } return peerAllowedIP, nil } diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index a0fff7713..5ef18a47e 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -13,8 +13,8 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/management/domain" @@ -48,7 +48,7 @@ type Route struct { currentPeerKey string cancel context.CancelFunc statusRecorder *peer.Status - wgInterface iface.IWGIface + wgInterface iface.WGIface resolverAddr string } @@ -58,7 +58,7 @@ func NewRoute( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, interval time.Duration, statusRecorder *peer.Status, - wgInterface iface.IWGIface, + wgInterface iface.WGIface, resolverAddr string, ) *Route { return &Route{ diff --git a/client/internal/routemanager/iface/iface.go b/client/internal/routemanager/iface/iface.go new file mode 100644 index 000000000..57dbec03d --- /dev/null +++ b/client/internal/routemanager/iface/iface.go @@ -0,0 +1,9 @@ +//go:build !windows +// +build !windows + +package iface + +// WGIface defines subset methods of interface required for router +type WGIface interface { + wgIfaceBase +} diff --git a/client/internal/routemanager/iface/iface_common.go b/client/internal/routemanager/iface/iface_common.go new file mode 100644 index 000000000..8b2dc9714 --- /dev/null +++ b/client/internal/routemanager/iface/iface_common.go @@ -0,0 +1,22 @@ +package iface + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) + +type wgIfaceBase interface { + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + + Name() string + Address() iface.WGAddress + ToInterface() *net.Interface + IsUserspaceBind() bool + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/client/internal/routemanager/iface/iface_windows.go b/client/internal/routemanager/iface/iface_windows.go new file mode 100644 index 000000000..7ab7e239c --- /dev/null +++ b/client/internal/routemanager/iface/iface_windows.go @@ -0,0 +1,7 @@ +package iface + +// WGIface defines subset methods of interface required for router +type WGIface interface { + wgIfaceBase + GetInterfaceGUIDString() (string, error) +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 52de0948b..ae0d1d220 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -15,13 +15,13 @@ import ( "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -52,7 +52,7 @@ type ManagerConfig struct { Context context.Context PublicKey string DNSRouteInterval time.Duration - WGInterface iface.IWGIface + WGInterface iface.WGIface StatusRecorder *peer.Status RelayManager *relayClient.Manager InitialRoutes []*route.Route @@ -74,7 +74,7 @@ type DefaultManager struct { sysOps *systemops.SysOps statusRecorder *peer.Status relayMgr *relayClient.Manager - wgInterface iface.IWGIface + wgInterface iface.WGIface pubKey string notifier *notifier.Notifier routeRefCounter *refcounter.RouteRefCounter diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index e9cfa0826..48bb0380d 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,8 +7,8 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/route" ) @@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { return nil } -func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) { +func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 7ddcccb0b..603818bba 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -11,8 +11,12 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" +<<<<<<< HEAD +======= + "github.com/netbirdio/netbird/client/internal/routemanager/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +>>>>>>> main "github.com/netbirdio/netbird/route" ) @@ -21,11 +25,11 @@ type serverRouter struct { ctx context.Context routes map[route.ID]*route.Route firewall firewall.Manager - wgInterface iface.IWGIface + wgInterface iface.WGIface statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { +func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { return &serverRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index bb620ee68..ea63f02fc 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" ) const ( @@ -23,7 +23,7 @@ const ( ) // Setup configures sysctl settings for RP filtering and source validation. -func Setup(wgIface iface.IWGIface) (map[string]int, error) { +func Setup(wgIface iface.WGIface) (map[string]int, error) { keys := map[string]int{} var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index d1cb83bfb..5c117b94d 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,7 +5,7 @@ import ( "net/netip" "sync" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) @@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter - wgInterface iface.IWGIface + wgInterface iface.WGIface // prefixes is tracking all the current added prefixes im memory // (this is used in iOS as all route updates require a full table update) //nolint @@ -30,7 +30,7 @@ type SysOps struct { notifier *notifier.Notifier } -func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps { +func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 31b7f3ac2..eaef01815 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,8 +16,8 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -149,7 +149,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) { +func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { addr := prefix.Addr() switch { case addr.IsLoopback(), diff --git a/client/internal/stdnet/filter.go b/client/internal/stdnet/filter.go index c04250b2d..e45714001 100644 --- a/client/internal/stdnet/filter.go +++ b/client/internal/stdnet/filter.go @@ -21,7 +21,6 @@ func InterfaceFilter(disallowList []string) func(string) bool { for _, s := range disallowList { if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" { - log.Tracef("ignoring interface %s - it is not allowed", iFace) return false } } diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 2e87475a5..aa9fdd045 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -5,11 +5,16 @@ package stdnet import ( "fmt" + "slices" + "sync" + "time" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) +const updateInterval = 30 * time.Second + // Net is an implementation of the net.Net interface // based on functions of the standard net package. type Net struct { @@ -18,6 +23,10 @@ type Net struct { iFaceDiscover iFaceDiscover // interfaceFilter should return true if the given interfaceName is allowed interfaceFilter func(interfaceName string) bool + lastUpdate time.Time + + // mu is shared between interfaces and lastUpdate + mu sync.Mutex } // NewNetWithDiscover creates a new StdNet instance. @@ -43,18 +52,40 @@ func NewNet(disallowList []string) (*Net, error) { // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one // wasn't specified. func (n *Net) UpdateInterfaces() (err error) { + n.mu.Lock() + defer n.mu.Unlock() + + return n.updateInterfaces() +} + +func (n *Net) updateInterfaces() (err error) { allIfaces, err := n.iFaceDiscover.iFaces() if err != nil { return err } + n.interfaces = n.filterInterfaces(allIfaces) + + n.lastUpdate = time.Now() + return nil } // Interfaces returns a slice of interfaces which are available on the // system func (n *Net) Interfaces() ([]*transport.Interface, error) { - return n.interfaces, nil + n.mu.Lock() + defer n.mu.Unlock() + + if time.Since(n.lastUpdate) < updateInterval { + return slices.Clone(n.interfaces), nil + } + + if err := n.updateInterfaces(); err != nil { + return nil, fmt.Errorf("update interfaces: %w", err) + } + + return slices.Clone(n.interfaces), nil } // InterfaceByIndex returns the interface specified by index. @@ -63,6 +94,8 @@ func (n *Net) Interfaces() ([]*transport.Interface, error) { // sharing the logical data link; for more precision use // InterfaceByName. func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { + n.mu.Lock() + defer n.mu.Unlock() for _, ifc := range n.interfaces { if ifc.Index == index { return ifc, nil @@ -74,6 +107,8 @@ func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { // InterfaceByName returns the interface specified by name. func (n *Net) InterfaceByName(name string) (*transport.Interface, error) { + n.mu.Lock() + defer n.mu.Unlock() for _, ifc := range n.interfaces { if ifc.Name == name { return ifc, nil @@ -87,7 +122,7 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I if n.interfaceFilter == nil { return interfaces } - result := []*transport.Interface{} + var result []*transport.Interface for _, iface := range interfaces { if n.interfaceFilter(iface.Name) { result = append(result, iface) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 0d7d7146a..d04d7a9c0 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -146,6 +146,7 @@ const ( SystemEvent_DNS SystemEvent_Category = 1 SystemEvent_AUTHENTICATION SystemEvent_Category = 2 SystemEvent_CONNECTIVITY SystemEvent_Category = 3 + SystemEvent_SYSTEM SystemEvent_Category = 4 ) // Enum value maps for SystemEvent_Category. @@ -155,12 +156,14 @@ var ( 1: "DNS", 2: "AUTHENTICATION", 3: "CONNECTIVITY", + 4: "SYSTEM", } SystemEvent_Category_value = map[string]int32{ "NETWORK": 0, "DNS": 1, "AUTHENTICATION": 2, "CONNECTIVITY": 3, + "SYSTEM": 4, } ) @@ -4020,7 +4023,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x87, 0x04, 0x0a, + 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x93, 0x04, 0x0a, 0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38, 0x0a, 0x08, 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, @@ -4048,116 +4051,117 @@ var file_daemon_proto_rawDesc = []byte{ 0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, 0x41, 0x52, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x02, - 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x46, + 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x52, 0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, 0x07, 0x4e, 0x45, 0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x49, - 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, - 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x2b, 0x0a, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, - 0x76, 0x65, 0x6e, 0x74, 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, - 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, - 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, - 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, - 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, - 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, - 0x32, 0xb3, 0x0b, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, - 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, - 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, 0x54, 0x45, 0x4d, + 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, + 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, 0x0a, 0x06, 0x65, + 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, + 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, + 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, + 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb3, 0x0b, 0x0a, + 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, + 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, + 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, - 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x4a, 0x0a, 0x0f, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, - 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0f, 0x46, + 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x14, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, + 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, + 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, - 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, - 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, - 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, - 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, + 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, + 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, + 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, + 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, + 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, - 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, - 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, - 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, - 0x30, 0x01, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, - 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73, + 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, + 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, + 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 4f2032d00..49e577853 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -457,6 +457,7 @@ message SystemEvent { DNS = 1; AUTHENTICATION = 2; CONNECTIVITY = 3; + SYSTEM = 4; } string id = 1; diff --git a/client/server/network.go b/client/server/network.go index aaf361524..d310f4da1 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -6,6 +6,7 @@ import ( "net/netip" "slices" "sort" + "strings" "golang.org/x/exp/maps" @@ -134,6 +135,18 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ } routeManager.TriggerSelection(routeManager.GetClientRoutes()) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_SYSTEM, + "Network selection changed", + "", + map[string]string{ + "networks": strings.Join(req.GetNetworkIDs(), ", "), + "append": fmt.Sprint(req.GetAppend()), + "all": fmt.Sprint(req.GetAll()), + }, + ) + return &proto.SelectNetworksResponse{}, nil } @@ -164,6 +177,18 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe } routeManager.TriggerSelection(routeManager.GetClientRoutes()) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_SYSTEM, + "Network deselection changed", + "", + map[string]string{ + "networks": strings.Join(req.GetNetworkIDs(), ", "), + "append": fmt.Sprint(req.GetAppend()), + "all": fmt.Sprint(req.GetAll()), + }, + ) + return &proto.SelectNetworksResponse{}, nil } diff --git a/client/server/server.go b/client/server/server.go index dcc6e5651..8907f541f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -751,6 +751,11 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto } + disableNotifications := true + if s.config.DisableNotifications != nil { + disableNotifications = *s.config.DisableNotifications + } + return &proto.GetConfigResponse{ ManagementUrl: managementURL, ConfigFile: s.latestConfigInput.ConfigPath, @@ -763,13 +768,14 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto ServerSSHAllowed: *s.config.ServerSSHAllowed, RosenpassEnabled: s.config.RosenpassEnabled, RosenpassPermissive: s.config.RosenpassPermissive, - DisableNotifications: s.config.DisableNotifications, + DisableNotifications: disableNotifications, }, nil } + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() - if !isUIActive { + if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications { if err := sendTerminalNotification(); err != nil { log.Errorf("send session expire terminal notification: %v", err) } diff --git a/client/server/server_test.go b/client/server/server_test.go index 278ec246c..0c0f32fec 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -135,7 +135,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { return nil, "", err } diff --git a/client/cmd/status_event.go b/client/status/event.go similarity index 86% rename from client/cmd/status_event.go rename to client/status/event.go index 9331570e6..2b65c9fa3 100644 --- a/client/cmd/status_event.go +++ b/client/status/event.go @@ -1,4 +1,4 @@ -package cmd +package status import ( "fmt" @@ -9,7 +9,7 @@ import ( "github.com/netbirdio/netbird/client/proto" ) -type systemEventOutput struct { +type SystemEventOutput struct { ID string `json:"id" yaml:"id"` Severity string `json:"severity" yaml:"severity"` Category string `json:"category" yaml:"category"` @@ -19,10 +19,10 @@ type systemEventOutput struct { Metadata map[string]string `json:"metadata" yaml:"metadata"` } -func mapEvents(protoEvents []*proto.SystemEvent) []systemEventOutput { - events := make([]systemEventOutput, len(protoEvents)) +func mapEvents(protoEvents []*proto.SystemEvent) []SystemEventOutput { + events := make([]SystemEventOutput, len(protoEvents)) for i, event := range protoEvents { - events[i] = systemEventOutput{ + events[i] = SystemEventOutput{ ID: event.GetId(), Severity: event.GetSeverity().String(), Category: event.GetCategory().String(), @@ -35,7 +35,7 @@ func mapEvents(protoEvents []*proto.SystemEvent) []systemEventOutput { return events } -func parseEvents(events []systemEventOutput) string { +func parseEvents(events []SystemEventOutput) string { if len(events) == 0 { return " No events recorded" } diff --git a/client/status/status.go b/client/status/status.go new file mode 100644 index 000000000..43acc9197 --- /dev/null +++ b/client/status/status.go @@ -0,0 +1,729 @@ +package status + +import ( + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "runtime" + "sort" + "strings" + "time" + + "gopkg.in/yaml.v3" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/version" +) + +type PeerStateDetailOutput struct { + FQDN string `json:"fqdn" yaml:"fqdn"` + IP string `json:"netbirdIp" yaml:"netbirdIp"` + PubKey string `json:"publicKey" yaml:"publicKey"` + Status string `json:"status" yaml:"status"` + LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` + ConnType string `json:"connectionType" yaml:"connectionType"` + IceCandidateType IceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` + IceCandidateEndpoint IceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` + RelayAddress string `json:"relayAddress" yaml:"relayAddress"` + LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` + TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` + TransferSent int64 `json:"transferSent" yaml:"transferSent"` + Latency time.Duration `json:"latency" yaml:"latency"` + RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` + Networks []string `json:"networks" yaml:"networks"` +} + +type PeersStateOutput struct { + Total int `json:"total" yaml:"total"` + Connected int `json:"connected" yaml:"connected"` + Details []PeerStateDetailOutput `json:"details" yaml:"details"` +} + +type SignalStateOutput struct { + URL string `json:"url" yaml:"url"` + Connected bool `json:"connected" yaml:"connected"` + Error string `json:"error" yaml:"error"` +} + +type ManagementStateOutput struct { + URL string `json:"url" yaml:"url"` + Connected bool `json:"connected" yaml:"connected"` + Error string `json:"error" yaml:"error"` +} + +type RelayStateOutputDetail struct { + URI string `json:"uri" yaml:"uri"` + Available bool `json:"available" yaml:"available"` + Error string `json:"error" yaml:"error"` +} + +type RelayStateOutput struct { + Total int `json:"total" yaml:"total"` + Available int `json:"available" yaml:"available"` + Details []RelayStateOutputDetail `json:"details" yaml:"details"` +} + +type IceCandidateType struct { + Local string `json:"local" yaml:"local"` + Remote string `json:"remote" yaml:"remote"` +} + +type NsServerGroupStateOutput struct { + Servers []string `json:"servers" yaml:"servers"` + Domains []string `json:"domains" yaml:"domains"` + Enabled bool `json:"enabled" yaml:"enabled"` + Error string `json:"error" yaml:"error"` +} + +type OutputOverview struct { + Peers PeersStateOutput `json:"peers" yaml:"peers"` + CliVersion string `json:"cliVersion" yaml:"cliVersion"` + DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` + ManagementState ManagementStateOutput `json:"management" yaml:"management"` + SignalState SignalStateOutput `json:"signal" yaml:"signal"` + Relays RelayStateOutput `json:"relays" yaml:"relays"` + IP string `json:"netbirdIp" yaml:"netbirdIp"` + PubKey string `json:"publicKey" yaml:"publicKey"` + KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` + FQDN string `json:"fqdn" yaml:"fqdn"` + RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` + RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` + Networks []string `json:"networks" yaml:"networks"` + NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"` + NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` + Events []SystemEventOutput `json:"events" yaml:"events"` +} + +func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview { + pbFullStatus := resp.GetFullStatus() + + managementState := pbFullStatus.GetManagementState() + managementOverview := ManagementStateOutput{ + URL: managementState.GetURL(), + Connected: managementState.GetConnected(), + Error: managementState.Error, + } + + signalState := pbFullStatus.GetSignalState() + signalOverview := SignalStateOutput{ + URL: signalState.GetURL(), + Connected: signalState.GetConnected(), + Error: signalState.Error, + } + + relayOverview := mapRelays(pbFullStatus.GetRelays()) + peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) + + overview := OutputOverview{ + Peers: peersOverview, + CliVersion: version.NetbirdVersion(), + DaemonVersion: resp.GetDaemonVersion(), + ManagementState: managementOverview, + SignalState: signalOverview, + Relays: relayOverview, + IP: pbFullStatus.GetLocalPeerState().GetIP(), + PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(), + KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(), + FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), + RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), + RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), + Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), + NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()), + NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), + Events: mapEvents(pbFullStatus.GetEvents()), + } + + if anon { + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + anonymizeOverview(anonymizer, &overview) + } + + return overview +} + +func mapRelays(relays []*proto.RelayState) RelayStateOutput { + var relayStateDetail []RelayStateOutputDetail + + var relaysAvailable int + for _, relay := range relays { + available := relay.GetAvailable() + relayStateDetail = append(relayStateDetail, + RelayStateOutputDetail{ + URI: relay.URI, + Available: available, + Error: relay.GetError(), + }, + ) + + if available { + relaysAvailable++ + } + } + + return RelayStateOutput{ + Total: len(relays), + Available: relaysAvailable, + Details: relayStateDetail, + } +} + +func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput { + mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers)) + for _, pbNsGroupServer := range servers { + mappedNSGroups = append(mappedNSGroups, NsServerGroupStateOutput{ + Servers: pbNsGroupServer.GetServers(), + Domains: pbNsGroupServer.GetDomains(), + Enabled: pbNsGroupServer.GetEnabled(), + Error: pbNsGroupServer.GetError(), + }) + } + return mappedNSGroups +} + +func mapPeers( + peers []*proto.PeerState, + statusFilter string, + prefixNamesFilter []string, + prefixNamesFilterMap map[string]struct{}, + ipsFilter map[string]struct{}, +) PeersStateOutput { + var peersStateDetail []PeerStateDetailOutput + peersConnected := 0 + for _, pbPeerState := range peers { + localICE := "" + remoteICE := "" + localICEEndpoint := "" + remoteICEEndpoint := "" + relayServerAddress := "" + connType := "" + lastHandshake := time.Time{} + transferReceived := int64(0) + transferSent := int64(0) + + isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() + if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) { + continue + } + if isPeerConnected { + peersConnected++ + + localICE = pbPeerState.GetLocalIceCandidateType() + remoteICE = pbPeerState.GetRemoteIceCandidateType() + localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint() + remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint() + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } + relayServerAddress = pbPeerState.GetRelayAddress() + lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() + transferReceived = pbPeerState.GetBytesRx() + transferSent = pbPeerState.GetBytesTx() + } + + timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local() + peerState := PeerStateDetailOutput{ + IP: pbPeerState.GetIP(), + PubKey: pbPeerState.GetPubKey(), + Status: pbPeerState.GetConnStatus(), + LastStatusUpdate: timeLocal, + ConnType: connType, + IceCandidateType: IceCandidateType{ + Local: localICE, + Remote: remoteICE, + }, + IceCandidateEndpoint: IceCandidateType{ + Local: localICEEndpoint, + Remote: remoteICEEndpoint, + }, + RelayAddress: relayServerAddress, + FQDN: pbPeerState.GetFqdn(), + LastWireguardHandshake: lastHandshake, + TransferReceived: transferReceived, + TransferSent: transferSent, + Latency: pbPeerState.GetLatency().AsDuration(), + RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), + Networks: pbPeerState.GetNetworks(), + } + + peersStateDetail = append(peersStateDetail, peerState) + } + + sortPeersByIP(peersStateDetail) + + peersOverview := PeersStateOutput{ + Total: len(peersStateDetail), + Connected: peersConnected, + Details: peersStateDetail, + } + return peersOverview +} + +func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) { + if len(peersStateDetail) > 0 { + sort.SliceStable(peersStateDetail, func(i, j int) bool { + iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP) + jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP) + return iAddr.Compare(jAddr) == -1 + }) + } +} + +func ParseToJSON(overview OutputOverview) (string, error) { + jsonBytes, err := json.Marshal(overview) + if err != nil { + return "", fmt.Errorf("json marshal failed") + } + return string(jsonBytes), err +} + +func ParseToYAML(overview OutputOverview) (string, error) { + yamlBytes, err := yaml.Marshal(overview) + if err != nil { + return "", fmt.Errorf("yaml marshal failed") + } + return string(yamlBytes), nil +} + +func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string { + var managementConnString string + if overview.ManagementState.Connected { + managementConnString = "Connected" + if showURL { + managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL) + } + } else { + managementConnString = "Disconnected" + if overview.ManagementState.Error != "" { + managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error) + } + } + + var signalConnString string + if overview.SignalState.Connected { + signalConnString = "Connected" + if showURL { + signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL) + } + } else { + signalConnString = "Disconnected" + if overview.SignalState.Error != "" { + signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error) + } + } + + interfaceTypeString := "Userspace" + interfaceIP := overview.IP + if overview.KernelInterface { + interfaceTypeString = "Kernel" + } else if overview.IP == "" { + interfaceTypeString = "N/A" + interfaceIP = "N/A" + } + + var relaysString string + if showRelays { + for _, relay := range overview.Relays.Details { + available := "Available" + reason := "" + if !relay.Available { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) + } + } else { + relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) + } + + networks := "-" + if len(overview.Networks) > 0 { + sort.Strings(overview.Networks) + networks = strings.Join(overview.Networks, ", ") + } + + var dnsServersString string + if showNameServers { + for _, nsServerGroup := range overview.NSServerGroups { + enabled := "Available" + if !nsServerGroup.Enabled { + enabled = "Unavailable" + } + errorString := "" + if nsServerGroup.Error != "" { + errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error) + errorString = strings.TrimSpace(errorString) + } + + domainsString := strings.Join(nsServerGroup.Domains, ", ") + if domainsString == "" { + domainsString = "." // Show "." for the default zone + } + dnsServersString += fmt.Sprintf( + "\n [%s] for [%s] is %s%s", + strings.Join(nsServerGroup.Servers, ", "), + domainsString, + enabled, + errorString, + ) + } + } else { + dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups)) + } + + rosenpassEnabledStatus := "false" + if overview.RosenpassEnabled { + rosenpassEnabledStatus = "true" + if overview.RosenpassPermissive { + rosenpassEnabledStatus = "true (permissive)" //nolint:gosec + } + } + + peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) + + goos := runtime.GOOS + goarch := runtime.GOARCH + goarm := "" + if goarch == "arm" { + goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM")) + } + + summary := fmt.Sprintf( + "OS: %s\n"+ + "Daemon version: %s\n"+ + "CLI version: %s\n"+ + "Management: %s\n"+ + "Signal: %s\n"+ + "Relays: %s\n"+ + "Nameservers: %s\n"+ + "FQDN: %s\n"+ + "NetBird IP: %s\n"+ + "Interface type: %s\n"+ + "Quantum resistance: %s\n"+ + "Networks: %s\n"+ + "Forwarding rules: %d\n"+ + "Peers count: %s\n", + fmt.Sprintf("%s/%s%s", goos, goarch, goarm), + overview.DaemonVersion, + version.NetbirdVersion(), + managementConnString, + signalConnString, + relaysString, + dnsServersString, + overview.FQDN, + interfaceIP, + interfaceTypeString, + rosenpassEnabledStatus, + networks, + overview.NumberOfForwardingRules, + peersCountString, + ) + return summary +} + +func ParseToFullDetailSummary(overview OutputOverview) string { + parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) + parsedEventsString := parseEvents(overview.Events) + summary := ParseGeneralSummary(overview, true, true, true) + + return fmt.Sprintf( + "Peers detail:"+ + "%s\n"+ + "Events:"+ + "%s\n"+ + "%s", + parsedPeersString, + parsedEventsString, + summary, + ) +} + +func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string { + var ( + peersString = "" + ) + + for _, peerState := range peers.Details { + + localICE := "-" + if peerState.IceCandidateType.Local != "" { + localICE = peerState.IceCandidateType.Local + } + + remoteICE := "-" + if peerState.IceCandidateType.Remote != "" { + remoteICE = peerState.IceCandidateType.Remote + } + + localICEEndpoint := "-" + if peerState.IceCandidateEndpoint.Local != "" { + localICEEndpoint = peerState.IceCandidateEndpoint.Local + } + + remoteICEEndpoint := "-" + if peerState.IceCandidateEndpoint.Remote != "" { + remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote + } + + rosenpassEnabledStatus := "false" + if rosenpassEnabled { + if peerState.RosenpassEnabled { + rosenpassEnabledStatus = "true" + } else { + if rosenpassPermissive { + rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)" + } else { + rosenpassEnabledStatus = "false (connection won't work without a permissive mode)" + } + } + } else { + if peerState.RosenpassEnabled { + rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)" + } + } + + networks := "-" + if len(peerState.Networks) > 0 { + sort.Strings(peerState.Networks) + networks = strings.Join(peerState.Networks, ", ") + } + + peerString := fmt.Sprintf( + "\n %s:\n"+ + " NetBird IP: %s\n"+ + " Public key: %s\n"+ + " Status: %s\n"+ + " -- detail --\n"+ + " Connection type: %s\n"+ + " ICE candidate (Local/Remote): %s/%s\n"+ + " ICE candidate endpoints (Local/Remote): %s/%s\n"+ + " Relay server address: %s\n"+ + " Last connection update: %s\n"+ + " Last WireGuard handshake: %s\n"+ + " Transfer status (received/sent) %s/%s\n"+ + " Quantum resistance: %s\n"+ + " Networks: %s\n"+ + " Latency: %s\n", + peerState.FQDN, + peerState.IP, + peerState.PubKey, + peerState.Status, + peerState.ConnType, + localICE, + remoteICE, + localICEEndpoint, + remoteICEEndpoint, + peerState.RelayAddress, + timeAgo(peerState.LastStatusUpdate), + timeAgo(peerState.LastWireguardHandshake), + toIEC(peerState.TransferReceived), + toIEC(peerState.TransferSent), + rosenpassEnabledStatus, + networks, + peerState.Latency.String(), + ) + + peersString += peerString + } + return peersString +} + +func skipDetailByFilters( + peerState *proto.PeerState, + isConnected bool, + statusFilter string, + prefixNamesFilter []string, + prefixNamesFilterMap map[string]struct{}, + ipsFilter map[string]struct{}, +) bool { + statusEval := false + ipEval := false + nameEval := true + + if statusFilter != "" { + lowerStatusFilter := strings.ToLower(statusFilter) + if lowerStatusFilter == "disconnected" && isConnected { + statusEval = true + } else if lowerStatusFilter == "connected" && !isConnected { + statusEval = true + } + } + + if len(ipsFilter) > 0 { + _, ok := ipsFilter[peerState.IP] + if !ok { + ipEval = true + } + } + + if len(prefixNamesFilter) > 0 { + for prefixNameFilter := range prefixNamesFilterMap { + if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { + nameEval = false + break + } + } + } else { + nameEval = false + } + + return statusEval || ipEval || nameEval +} + +func toIEC(b int64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", + float64(b)/float64(div), "KMGTPE"[exp]) +} + +func countEnabled(dnsServers []NsServerGroupStateOutput) int { + count := 0 + for _, server := range dnsServers { + if server.Enabled { + count++ + } + } + return count +} + +// timeAgo returns a string representing the duration since the provided time in a human-readable format. +func timeAgo(t time.Time) string { + if t.IsZero() || t.Equal(time.Unix(0, 0)) { + return "-" + } + duration := time.Since(t) + switch { + case duration < time.Second: + return "Now" + case duration < time.Minute: + seconds := int(duration.Seconds()) + if seconds == 1 { + return "1 second ago" + } + return fmt.Sprintf("%d seconds ago", seconds) + case duration < time.Hour: + minutes := int(duration.Minutes()) + seconds := int(duration.Seconds()) % 60 + if minutes == 1 { + if seconds == 1 { + return "1 minute, 1 second ago" + } else if seconds > 0 { + return fmt.Sprintf("1 minute, %d seconds ago", seconds) + } + return "1 minute ago" + } + if seconds > 0 { + return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds) + } + return fmt.Sprintf("%d minutes ago", minutes) + case duration < 24*time.Hour: + hours := int(duration.Hours()) + minutes := int(duration.Minutes()) % 60 + if hours == 1 { + if minutes == 1 { + return "1 hour, 1 minute ago" + } else if minutes > 0 { + return fmt.Sprintf("1 hour, %d minutes ago", minutes) + } + return "1 hour ago" + } + if minutes > 0 { + return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes) + } + return fmt.Sprintf("%d hours ago", hours) + } + + days := int(duration.Hours()) / 24 + hours := int(duration.Hours()) % 24 + if days == 1 { + if hours == 1 { + return "1 day, 1 hour ago" + } else if hours > 0 { + return fmt.Sprintf("1 day, %d hours ago", hours) + } + return "1 day ago" + } + if hours > 0 { + return fmt.Sprintf("%d days, %d hours ago", days, hours) + } + return fmt.Sprintf("%d days ago", days) +} + +func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) { + peer.FQDN = a.AnonymizeDomain(peer.FQDN) + if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil { + peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port) + } + if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { + peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) + } + + peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeIPString(route) + } + + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeRoute(route) + } +} + +func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { + for i, peer := range overview.Peers.Details { + peer := peer + anonymizePeerDetail(a, &peer) + overview.Peers.Details[i] = peer + } + + overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL) + overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error) + overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL) + overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error) + + overview.IP = a.AnonymizeIPString(overview.IP) + for i, detail := range overview.Relays.Details { + detail.URI = a.AnonymizeURI(detail.URI) + detail.Error = a.AnonymizeString(detail.Error) + overview.Relays.Details[i] = detail + } + + for i, nsGroup := range overview.NSServerGroups { + for j, domain := range nsGroup.Domains { + overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain) + } + for j, ns := range nsGroup.Servers { + host, port, err := net.SplitHostPort(ns) + if err == nil { + overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port) + } + } + } + + for i, route := range overview.Networks { + overview.Networks[i] = a.AnonymizeRoute(route) + } + + overview.FQDN = a.AnonymizeDomain(overview.FQDN) + + for i, event := range overview.Events { + overview.Events[i].Message = a.AnonymizeString(event.Message) + overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage) + + for k, v := range event.Metadata { + event.Metadata[k] = a.AnonymizeString(v) + } + } +} diff --git a/client/status/status_test.go b/client/status/status_test.go new file mode 100644 index 000000000..e48b441f5 --- /dev/null +++ b/client/status/status_test.go @@ -0,0 +1,607 @@ +package status + +import ( + "bytes" + "encoding/json" + "fmt" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/version" +) + +func init() { + loc, err := time.LoadLocation("UTC") + if err != nil { + panic(err) + } + + time.Local = loc +} + +var resp = &proto.StatusResponse{ + Status: "Connected", + FullStatus: &proto.FullStatus{ + Peers: []*proto.PeerState{ + { + IP: "192.168.178.101", + PubKey: "Pubkey1", + Fqdn: "peer-1.awesome-domain.com", + ConnStatus: "Connected", + ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), + Relayed: false, + LocalIceCandidateType: "", + RemoteIceCandidateType: "", + LocalIceCandidateEndpoint: "", + RemoteIceCandidateEndpoint: "", + LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), + BytesRx: 200, + BytesTx: 100, + Networks: []string{ + "10.1.0.0/24", + }, + Latency: durationpb.New(time.Duration(10000000)), + }, + { + IP: "192.168.178.102", + PubKey: "Pubkey2", + Fqdn: "peer-2.awesome-domain.com", + ConnStatus: "Connected", + ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), + Relayed: true, + LocalIceCandidateType: "relay", + RemoteIceCandidateType: "prflx", + LocalIceCandidateEndpoint: "10.0.0.1:10001", + RemoteIceCandidateEndpoint: "10.0.10.1:10002", + LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)), + BytesRx: 2000, + BytesTx: 1000, + Latency: durationpb.New(time.Duration(10000000)), + }, + }, + ManagementState: &proto.ManagementState{ + URL: "my-awesome-management.com:443", + Connected: true, + Error: "", + }, + SignalState: &proto.SignalState{ + URL: "my-awesome-signal.com:443", + Connected: true, + Error: "", + }, + Relays: []*proto.RelayState{ + { + URI: "stun:my-awesome-stun.com:3478", + Available: true, + Error: "", + }, + { + URI: "turns:my-awesome-turn.com:443?transport=tcp", + Available: false, + Error: "context: deadline exceeded", + }, + }, + LocalPeerState: &proto.LocalPeerState{ + IP: "192.168.178.100/16", + PubKey: "Some-Pub-Key", + KernelInterface: true, + Fqdn: "some-localhost.awesome-domain.com", + Networks: []string{ + "10.10.0.0/24", + }, + }, + DnsServers: []*proto.NSGroupState{ + { + Servers: []string{ + "8.8.8.8:53", + }, + Domains: nil, + Enabled: true, + Error: "", + }, + { + Servers: []string{ + "1.1.1.1:53", + "2.2.2.2:53", + }, + Domains: []string{ + "example.com", + "example.net", + }, + Enabled: false, + Error: "timeout", + }, + }, + }, + DaemonVersion: "0.14.1", +} + +var overview = OutputOverview{ + Peers: PeersStateOutput{ + Total: 2, + Connected: 2, + Details: []PeerStateDetailOutput{ + { + IP: "192.168.178.101", + PubKey: "Pubkey1", + FQDN: "peer-1.awesome-domain.com", + Status: "Connected", + LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), + ConnType: "P2P", + IceCandidateType: IceCandidateType{ + Local: "", + Remote: "", + }, + IceCandidateEndpoint: IceCandidateType{ + Local: "", + Remote: "", + }, + LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC), + TransferReceived: 200, + TransferSent: 100, + Networks: []string{ + "10.1.0.0/24", + }, + Latency: time.Duration(10000000), + }, + { + IP: "192.168.178.102", + PubKey: "Pubkey2", + FQDN: "peer-2.awesome-domain.com", + Status: "Connected", + LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), + ConnType: "Relayed", + IceCandidateType: IceCandidateType{ + Local: "relay", + Remote: "prflx", + }, + IceCandidateEndpoint: IceCandidateType{ + Local: "10.0.0.1:10001", + Remote: "10.0.10.1:10002", + }, + LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC), + TransferReceived: 2000, + TransferSent: 1000, + Latency: time.Duration(10000000), + }, + }, + }, + Events: []SystemEventOutput{}, + CliVersion: version.NetbirdVersion(), + DaemonVersion: "0.14.1", + ManagementState: ManagementStateOutput{ + URL: "my-awesome-management.com:443", + Connected: true, + Error: "", + }, + SignalState: SignalStateOutput{ + URL: "my-awesome-signal.com:443", + Connected: true, + Error: "", + }, + Relays: RelayStateOutput{ + Total: 2, + Available: 1, + Details: []RelayStateOutputDetail{ + { + URI: "stun:my-awesome-stun.com:3478", + Available: true, + Error: "", + }, + { + URI: "turns:my-awesome-turn.com:443?transport=tcp", + Available: false, + Error: "context: deadline exceeded", + }, + }, + }, + IP: "192.168.178.100/16", + PubKey: "Some-Pub-Key", + KernelInterface: true, + FQDN: "some-localhost.awesome-domain.com", + NSServerGroups: []NsServerGroupStateOutput{ + { + Servers: []string{ + "8.8.8.8:53", + }, + Domains: nil, + Enabled: true, + Error: "", + }, + { + Servers: []string{ + "1.1.1.1:53", + "2.2.2.2:53", + }, + Domains: []string{ + "example.com", + "example.net", + }, + Enabled: false, + Error: "timeout", + }, + }, + Networks: []string{ + "10.10.0.0/24", + }, +} + +func TestConversionFromFullStatusToOutputOverview(t *testing.T) { + convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil) + + assert.Equal(t, overview, convertedResult) +} + +func TestSortingOfPeers(t *testing.T) { + peers := []PeerStateDetailOutput{ + { + IP: "192.168.178.104", + }, + { + IP: "192.168.178.102", + }, + { + IP: "192.168.178.101", + }, + { + IP: "192.168.178.105", + }, + { + IP: "192.168.178.103", + }, + } + + sortPeersByIP(peers) + + assert.Equal(t, peers[3].IP, "192.168.178.104") +} + +func TestParsingToJSON(t *testing.T) { + jsonString, _ := ParseToJSON(overview) + + //@formatter:off + expectedJSONString := ` + { + "peers": { + "total": 2, + "connected": 2, + "details": [ + { + "fqdn": "peer-1.awesome-domain.com", + "netbirdIp": "192.168.178.101", + "publicKey": "Pubkey1", + "status": "Connected", + "lastStatusUpdate": "2001-01-01T01:01:01Z", + "connectionType": "P2P", + "iceCandidateType": { + "local": "", + "remote": "" + }, + "iceCandidateEndpoint": { + "local": "", + "remote": "" + }, + "relayAddress": "", + "lastWireguardHandshake": "2001-01-01T01:01:02Z", + "transferReceived": 200, + "transferSent": 100, + "latency": 10000000, + "quantumResistance": false, + "networks": [ + "10.1.0.0/24" + ] + }, + { + "fqdn": "peer-2.awesome-domain.com", + "netbirdIp": "192.168.178.102", + "publicKey": "Pubkey2", + "status": "Connected", + "lastStatusUpdate": "2002-02-02T02:02:02Z", + "connectionType": "Relayed", + "iceCandidateType": { + "local": "relay", + "remote": "prflx" + }, + "iceCandidateEndpoint": { + "local": "10.0.0.1:10001", + "remote": "10.0.10.1:10002" + }, + "relayAddress": "", + "lastWireguardHandshake": "2002-02-02T02:02:03Z", + "transferReceived": 2000, + "transferSent": 1000, + "latency": 10000000, + "quantumResistance": false, + "networks": null + } + ] + }, + "cliVersion": "development", + "daemonVersion": "0.14.1", + "management": { + "url": "my-awesome-management.com:443", + "connected": true, + "error": "" + }, + "signal": { + "url": "my-awesome-signal.com:443", + "connected": true, + "error": "" + }, + "relays": { + "total": 2, + "available": 1, + "details": [ + { + "uri": "stun:my-awesome-stun.com:3478", + "available": true, + "error": "" + }, + { + "uri": "turns:my-awesome-turn.com:443?transport=tcp", + "available": false, + "error": "context: deadline exceeded" + } + ] + }, + "netbirdIp": "192.168.178.100/16", + "publicKey": "Some-Pub-Key", + "usesKernelInterface": true, + "fqdn": "some-localhost.awesome-domain.com", + "quantumResistance": false, + "quantumResistancePermissive": false, + "networks": [ + "10.10.0.0/24" + ], + "forwardingRules": 0, + "dnsServers": [ + { + "servers": [ + "8.8.8.8:53" + ], + "domains": null, + "enabled": true, + "error": "" + }, + { + "servers": [ + "1.1.1.1:53", + "2.2.2.2:53" + ], + "domains": [ + "example.com", + "example.net" + ], + "enabled": false, + "error": "timeout" + } + ], + "events": [] + }` + // @formatter:on + + var expectedJSON bytes.Buffer + require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString))) + + assert.Equal(t, expectedJSON.String(), jsonString) +} + +func TestParsingToYAML(t *testing.T) { + yaml, _ := ParseToYAML(overview) + + expectedYAML := + `peers: + total: 2 + connected: 2 + details: + - fqdn: peer-1.awesome-domain.com + netbirdIp: 192.168.178.101 + publicKey: Pubkey1 + status: Connected + lastStatusUpdate: 2001-01-01T01:01:01Z + connectionType: P2P + iceCandidateType: + local: "" + remote: "" + iceCandidateEndpoint: + local: "" + remote: "" + relayAddress: "" + lastWireguardHandshake: 2001-01-01T01:01:02Z + transferReceived: 200 + transferSent: 100 + latency: 10ms + quantumResistance: false + networks: + - 10.1.0.0/24 + - fqdn: peer-2.awesome-domain.com + netbirdIp: 192.168.178.102 + publicKey: Pubkey2 + status: Connected + lastStatusUpdate: 2002-02-02T02:02:02Z + connectionType: Relayed + iceCandidateType: + local: relay + remote: prflx + iceCandidateEndpoint: + local: 10.0.0.1:10001 + remote: 10.0.10.1:10002 + relayAddress: "" + lastWireguardHandshake: 2002-02-02T02:02:03Z + transferReceived: 2000 + transferSent: 1000 + latency: 10ms + quantumResistance: false + networks: [] +cliVersion: development +daemonVersion: 0.14.1 +management: + url: my-awesome-management.com:443 + connected: true + error: "" +signal: + url: my-awesome-signal.com:443 + connected: true + error: "" +relays: + total: 2 + available: 1 + details: + - uri: stun:my-awesome-stun.com:3478 + available: true + error: "" + - uri: turns:my-awesome-turn.com:443?transport=tcp + available: false + error: 'context: deadline exceeded' +netbirdIp: 192.168.178.100/16 +publicKey: Some-Pub-Key +usesKernelInterface: true +fqdn: some-localhost.awesome-domain.com +quantumResistance: false +quantumResistancePermissive: false +networks: + - 10.10.0.0/24 +forwardingRules: 0 +dnsServers: + - servers: + - 8.8.8.8:53 + domains: [] + enabled: true + error: "" + - servers: + - 1.1.1.1:53 + - 2.2.2.2:53 + domains: + - example.com + - example.net + enabled: false + error: timeout +events: [] +` + + assert.Equal(t, expectedYAML, yaml) +} + +func TestParsingToDetail(t *testing.T) { + // Calculate time ago based on the fixture dates + lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate) + lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake) + lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate) + lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake) + + detail := ParseToFullDetailSummary(overview) + + expectedDetail := fmt.Sprintf( + `Peers detail: + peer-1.awesome-domain.com: + NetBird IP: 192.168.178.101 + Public key: Pubkey1 + Status: Connected + -- detail -- + Connection type: P2P + ICE candidate (Local/Remote): -/- + ICE candidate endpoints (Local/Remote): -/- + Relay server address: + Last connection update: %s + Last WireGuard handshake: %s + Transfer status (received/sent) 200 B/100 B + Quantum resistance: false + Networks: 10.1.0.0/24 + Latency: 10ms + + peer-2.awesome-domain.com: + NetBird IP: 192.168.178.102 + Public key: Pubkey2 + Status: Connected + -- detail -- + Connection type: Relayed + ICE candidate (Local/Remote): relay/prflx + ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 + Relay server address: + Last connection update: %s + Last WireGuard handshake: %s + Transfer status (received/sent) 2.0 KiB/1000 B + Quantum resistance: false + Networks: - + Latency: 10ms + +Events: No events recorded +OS: %s/%s +Daemon version: 0.14.1 +CLI version: %s +Management: Connected to my-awesome-management.com:443 +Signal: Connected to my-awesome-signal.com:443 +Relays: + [stun:my-awesome-stun.com:3478] is Available + [turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded +Nameservers: + [8.8.8.8:53] for [.] is Available + [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout +FQDN: some-localhost.awesome-domain.com +NetBird IP: 192.168.178.100/16 +Interface type: Kernel +Quantum resistance: false +Networks: 10.10.0.0/24 +Forwarding rules: 0 +Peers count: 2/2 Connected +`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) + + assert.Equal(t, expectedDetail, detail) +} + +func TestParsingToShortVersion(t *testing.T) { + shortVersion := ParseGeneralSummary(overview, false, false, false) + + expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` +Daemon version: 0.14.1 +CLI version: development +Management: Connected +Signal: Connected +Relays: 1/2 Available +Nameservers: 1/2 Available +FQDN: some-localhost.awesome-domain.com +NetBird IP: 192.168.178.100/16 +Interface type: Kernel +Quantum resistance: false +Networks: 10.10.0.0/24 +Forwarding rules: 0 +Peers count: 2/2 Connected +` + + assert.Equal(t, expectedString, shortVersion) +} + +func TestTimeAgo(t *testing.T) { + now := time.Now() + + cases := []struct { + name string + input time.Time + expected string + }{ + {"Now", now, "Now"}, + {"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"}, + {"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"}, + {"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"}, + {"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"}, + {"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"}, + {"One day ago", now.Add(-24 * time.Hour), "1 day ago"}, + {"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"}, + {"Zero time", time.Time{}, "-"}, + {"Unix zero time", time.Unix(0, 0), "-"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := timeAgo(tc.input) + assert.Equal(t, tc.expected, result, "Failed %s", tc.name) + }) + } +} diff --git a/client/system/info.go b/client/system/info.go index d83e9509a..2a0343ca6 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -9,7 +9,6 @@ import ( "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/management/proto" - "github.com/netbirdio/netbird/version" ) // DeviceNameCtxKey context key for device name @@ -119,11 +118,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string { return v } -// GetDesktopUIUserAgent returns the Desktop ui user agent -func GetDesktopUIUserAgent() string { - return "netbird-desktop-ui/" + version.NetbirdVersion() -} - func networkAddresses() ([]NetworkAddress, error) { interfaces, err := net.Interfaces() if err != nil { diff --git a/client/ui/bundled.go b/client/ui/bundled.go deleted file mode 100644 index e2c138b14..000000000 --- a/client/ui/bundled.go +++ /dev/null @@ -1,12 +0,0 @@ -// auto-generated -// Code generated by '$ fyne bundle'. DO NOT EDIT. - -package main - -import "fyne.io/fyne/v2" - -var resourceNetbirdSystemtrayConnectedPng = &fyne.StaticResource{ - StaticName: "netbird-systemtray-connected.png", - StaticContent: []byte( - "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\b\x06\x00\x00\x00\\r\xa8f\x00\x00\x00\xc3zTXtRaw profile type exif\x00\x00x\xdamP\xdb\r\xc3 \f\xfc\xf7\x14\x1d\xc1\xaf\x80\x19\x874\xa9\xd4\r:~\r8Q\x88r\x92χ\x9d\x1cư\xff\xbe\x1fx50)\xe8\x92-\x95\x94СE\vW\x17\x86\x03\xb53\xa1v>@\xc1S\x1dNɞų\x8c\x86\xa5\xf8\xeb\xa8\xd3d\x83T]-\x17#{Gc\x9d\x1bEGf\xbb\x19\xc5E\xd2&b\x17[\x18\x950\x12\x1e\r\n\x83:\x9e\x85\xa9X\xbe>a\xddq\x86\x8d\x80F\x92\xbb\xf7ir?k\xf6\xedm\x8b\x17\x85y\x17\x12t\x16\xd11\x80\xb4P\x90\xdaE\xf5\xf0\xa1\xfc#u-\x92:[L\xe2\vy\xda\xd3\x01\xf8\x03\xda\xd4Y\x17ݮ\xb7\xee\x00\x00\x01\x84iCCPICC profile\x00\x00x\x9c}\x91=H\xc3@\x1c\xc5_S\xa5\"-\x0e\x16\x14\x11\xccP\x9d\xec\xa2\"\xe2T\xabP\x84\n\xa5Vh\xd5\xc1\xe4\xd2/hҐ\xa4\xb88\n\xae\x05\a?\x16\xab\x0e.κ:\xb8\n\x82\xe0\a\x88\xb3\x83\x93\xa2\x8b\x94\xf8\xbf\xa4\xd0\"ƃ\xe3~\xbc\xbb\xf7\xb8{\a\b\x8d\nSͮ\x18\xa0j\x96\x91N\xc4\xc5lnU\f\xbcB\xc0\x00B\x18\xc1\xac\xc4L}.\x95J\xc2s|\xdd\xc3\xc7\u05fb(\xcf\xf2>\xf7\xe7\b)y\x93\x01>\x918\xc6t\xc3\"\xde \x9e\u07b4t\xce\xfb\xc4aV\x92\x14\xe2s\xe2q\x83.H\xfc\xc8u\xd9\xe57\xceE\x87\x05\x9e\x1962\xe9y\xe20\xb1X\xec`\xb9\x83Y\xc9P\x89\xa7\x88#\x8a\xaaQ\xbe\x90uY\xe1\xbc\xc5Y\xad\xd4X\xeb\x9e\xfc\x85\xc1\xbc\xb6\xb2\xccu\x9a\xc3H`\x11KHA\x84\x8c\x1aʨ\xc0B\x94V\x8d\x14\x13iڏ{\xf8\x87\x1c\x7f\x8a\\2\xb9\xca`\xe4X@\x15*$\xc7\x0f\xfe\a\xbf\xbb5\v\x93\x13nR0\x0et\xbf\xd8\xf6\xc7(\x10\xd8\x05\x9au\xdb\xfe>\xb6\xed\xe6\t\xe0\x7f\x06\xae\xb4\xb6\xbf\xda\x00f>I\xaf\xb7\xb5\xc8\x11з\r\\\\\xb75y\x0f\xb8\xdc\x01\x06\x9ftɐ\x1c\xc9OS(\x14\x80\xf73\xfa\xa6\x1c\xd0\x7f\v\xf4\xae\xb9\xbd\xb5\xf6q\xfa\x00d\xa8\xab\xe4\rpp\b\x8c\x15){\xdd\xe3\xdd=\x9d\xbd\xfd{\xa6\xd5\xdf\x0fںr\xd0VwQ\xba\x00\x00\rxiTXtXML:com.adobe.xmp\x00\x00\x00\x00\x00\n\n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\xf0C\xff\xd9\x00\x00\x00\x06bKGD\x00\xff\x00\xff\x00\xff\xa0\xbd\xa7\x93\x00\x00\x00\tpHYs\x00\x00\v\x13\x00\x00\v\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\atIME\a\xe8\x02\x17\r$'\xdd\xf7ȗ\x00\x00\x13;IDATx\xda\xed\x9d]o\x14W\x9a\xc7\xff\xa7\xaamh\xbf\xc46,I`\x99\xa1\xc3\ni\xb5{1\x95O0\xe4\x1b\xc0'X\xf2\t`.W`hp\xa2\xb9\fH{O\xa3\xcc\xc5\xecJ3q\xa4\x1d\xed\xcdJx>Aj/\"EBJګL \xb1\x00g\xf1\v\xb6\xbb\xeb\xec\x85mb\f\xb6\xfb\xa5^Ω\xfa\xfd\xee\x928v\xf7\xa9z\xfe\xcfs\x9e\xa7ο$\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\u0603a\t\xc0g\xd6\x7f\x1f5\x92\x8e\"k\xd4\b\xa4s\xb2jH\x9afez\n\xfe\xdb\b\x00x\x81mF\xd3/CE]\xa3(\x94~c\xa5\x8b;\xc1\x0e\x83E\x7f{\xecF\xfcA\x8d\x95\x00g\xb3\xfb\\tQ\xd2o\xadtq]\xba(I\x81\x95,K345\xa3˒\x84\x00\x80SY~5Х0\xd0o\x13\xabK\x96R>\x9b\xe4o\xd4\x1a\xbd\x1e\xc7\b\x008\x93\xe9\xadtkM\x8a\x02i\xdaZ\x9aS\x99\x12\xea\xf6\xabJ\x80Հ\x02\xf7\xf4W\x13\xe9\xdan\xa6'\xe8sXw\xe9\xf6ؿ\xc6\xed_Z\x01\x00\x05d{\xed\xec\xe9!\xcf\xda\x7f\xbb\xf1\xf7Z/\x80U\x81<\x03\xdf\x12\xf8\xc5\xc5\x7f\xf2K\xe9O\x05\x00d\xfcje\xffx\xecF\xfc\xe1\xfe\x7fM\x05\x00\xd9\x04~3j$5ݲVWX\r\a\xe2?\xdc\x1e\xfb!\x00\x909ks\xd1\xd5Dj\x1a\xcb\x18ω\xe07j\xd5\xf74\xfe\x10\x00Ȅ\x95O\xa3(Ht_R\xc4\xdeҙҿ\xbdw췟\x80\x15\x82\xb4\xb2~\x90\xe8+I\x11\xab\xe1\x0e\xd6\xea\xc1A\xd9\x7f[\x1f\x00\x86\xdc\xeb\xdbP\xf7E\x93\xcf\xc9\xec\xbf\x7f\xec\xc7\x16\x00\xd2\v\xfe\xb9\xe8b\"}axd\xd7\xcd\xf8O\x0e.\xfd\xd9\x02\xc0\xb0\xc1\x7f\xcbJ\x0f\t~G\t4_\xbf\x19\xb7\x8e\xfa1*\x00\xe8\x9b\xd5O\xa2\xfb\x8c\xf7\x1c\xcf\xfe\x81~\xd7\xcb\xcf!\x00\xd03\xb6\x19M\xaf\x87\xfaB\x96\xfd\xbe\xd3\xc1\x7f\xc8؏-\x00\fV\xf27\xa3\xc6z\xa8\x87\xa2\xd9\xe7x\xf4\x1f>\xf6\xa3\x02\x80\x81\x82\xdf\xd6\xf4\x10\a\x1e\x0f\xe2?\xd1\xed\xfa\x8d\u07b2\xff\xb6^\x00\x10\xfc\xa5\xc9\xfeG\x8d\xfdJW\x01\xd8f4\xfd\xf2ؾNt\xe7\xe8\x9b5\f\xb4\xdc\r\xb4\xbc\xfb\xcf\xc77\xb4l\x9a\xf12wѾ=?\xc1\xef\xcf\xf5\xb2\xbd5\xfe\x9c\xac\x00\xd6\x7f\x1f5\xc2D\xd3[\x89\x1a\xd6jZ\x81\xa6\x8d\xd5t`tn\xe7\xcb5$M\xcb\xec\x04{\x867\xa5\x95\x96\x8dѲ\xacvK\xa9ec\xb4\x9cX-Z\xa3ec\xd5\x0e\xa4e\xd5\xd4\xee\xb5\xd9\xe2#ks\x11O\xf6\xf9\x92\xfc\x8dZ\xf5\x1b\xf1\xc7N\n\x80mF\xd3[#jlv\x15)\xd0\xf4\x1e\xfb憌\xa6\xbd\xcf0F\xed\x1d\xb1X\xb6\xd2\xffX\xabvh\xd4>\xdeU\xeckU\xb1v'\xfaLF\xd7\b-On\xc1\x9a>\x18$\x19\x99,\x82<\b\xf4\x1b\xb3\xed\xed\x16Y\xa9Q\xe5\x87E\xac\xb4l\xa4XFq\"-\x86V\xb1\xeb°\xf3\x90O\x93\xb0\xf2\xe6\x1e\xbb=>\x1b\x0ft\xbd\xcc\xc0\x81nuq'\x93Gv\xfb\xf4\x17O\x84\xf5G,\xa9\x9d\x18\xfd5\xb4\x8a\xeb\xb3\xf1\x82\v\x1fju.\xbad\xa4/\xb8<\xfeT\x9f\xf5\x8e>\x1c4\xa1\x98\xa3\x82}-\xd4Ek\xd4\xe0e\f\xb9\xb0 \xa3\xd8Z\xfdu\xac\xab\x85\xbc\xab\x04:\xfe\x1eƿ\xd5ǽ<\xf2ۓ\x00l~\x1aE\x9bV\x17\tv\x87\xaa\x04\xa3\x05c\xf5e\x1e\x15\xc2\xda'\xd1w\\s\xbf\xb2\x7f\xbfc\xbf7~\xc5\xda\\tU\xd2%\xcax/z\t\v\x89\u0557\xe1\x88\x16Ҟ>\xb0\xef\xf70\xfe\al\xfc\xbd\xf6;V>\x8d\"\x93p\xaa\xcb\xc7\xedBb\xf5 \r1\xd89\xd3\xff\x1dK\xeaQ\xf0\x0f8\xf6{\xeb\x16`ǹ\xf5!\xcbZM1X\x9d\x8b\xbe2\xcc\xfb\xbd\xaa\x06\x83\x9a>L\xa3\n|\xd5\x03X\xbf\x13]\xb1F\xf7Y^\uf677҃\xf1\xd9x\xbe\x97\x1f^\xb9\x13]\t\xb8\xee\xbe\t\xc0\xc0c\xbf\x03\x05\x00\x11([\x8d\xa8\xb6\x12͛\x11\xdd;,S\xd0\xf8\xf3\xef\xba\x0e\xdb\xf8\xdb\xcbkǁ\xeb7\xe3\x96U\xefG\t\xc1\xe94ѐ\xd15\xdb\xd1w\xab\x9fD\xf7w^\xb5\xfd\xfa\xde\x7f.\xbaE\xf0{\x16\xffI\xba\xf1i\x0e\xd8\x136\xcd\xf6\xdb\\\xa0d\xd9#It{\xe2fܢ\xf1\xe7\xe5\xf5[\x18\xbb\x11\x7f\x94\xb9\x00 \x02\x15\x10\x82\xe7\x13\xed`z\xe5\"\x8b\xe1\xd1eKa\xec׳\x00 \x02%\xde\x1dlִ\xf5\xf5\xafeF;\nO?Wp\xe2\x05\x8b\xe2z\xf0\xa74\xf6;\xb4\a\xb0\x9f\xf1ٸ\x99H\x0fX\xfer\xd1yt\xe6\x95\x10t\x16Oi\xeb\xeb_+y6\xc9\xc28\\\xb1\xf5c\xf3\x95\x9a\x00H\xd2\xc4l|\x05\x11(\x0fɳI\xd9\xcd\xda\x1b\x15\xc1\xae\x10ؕ:\x8b\xe4\x1e\xf7\xb2\xf2\x9d\xe8\xf94\xe0\xfa\\\xf4\x90w\xbb{^\xfaw\x03u\xbe9\xfb\x86\x00\xbc\x91\x15N\xac(<\xfd\\ft\x8bEs \xfb\xa79\xf6\xeb\xbb\x02\xd8\xe5eW\x97\xed\xf6\x11V\xf05\xfb\xff4ud\xf0oW\t\x13\xda\xfa\xfaW\xea>\x9eaъ\x8e\xff$۱|_~\x00ϛ\xd1\xf4h\xa8\x87<6\xeaa\xf6\xdfi\xfc\xf5}\x83\x8cvT\xbb\xf0\x98j\xa0\x88\xe0Ϩ\xf17P\x05 I3\xcdx9\xe8게\xda\\\x1e\xbf\x184\x9bo\vǯ\xd4\xfd\xfe\xa4l\x97\xd7H\xe4J\x98\xfdCy}_\xd1z3n\x9b\x8e>B\x04<\xca\xfe+\xf5\xa1\xbb\xfcݥ\xa9\x9d\xfe\xc1\b\v\x9a\xc75\x93n\xe7a8;\x90\xa4#\x02~\xd1Y<\x95\xe26\x82\xde@\xf6\xb5\xbf\xdaAM\xad<\xfe\xd4\xc05\x1d\"\xe0\ao\x1b\xfb\r\xbd\x9dx2\xa3-\xaa\x81\xec\xe2?\xc9'\xfbok͐`(\xe2p\x19\xb9YS\xe7љ\xd4\x05\xe0\xd5\xcd3\xdaQx\xf6\xa9\x82\xa9U\x16;\xc5\xec\x9f\xe5\xd8/\xb5\n`\x97\x89\xebql\x03}d%ު\xe3\x18\xdd\xc73\x99\x05\xff+\x81\xf9\xf6=\xb6\x04)R3\xba\x9c\xe7\xdfK\xa5\xad;q=\x8e%}\xcc\xe5s+\xfb\xe7\xf5xo\xf7Ɍ:\x8b\xef2%\x186\xf9\x1b\xb5F\xb7c\xc9/\x01\x90\xa4\xf1\xd9x\xdeXD\xc0\xa5\xec\x9fo\xafa\x82)\xc1\xb0\x84\xf9{q\xa4*\xd9\xf5\x9bqK\xa6\xff\x17\x14B\xda\xc18Y\xc8\xe1\x9e\xed\x9e\xc3iD`\x90\xb5S~\x8d\xbf\xd7[\x0e\x19\xc01\xe2b\xd9\xfa\xfaי\xee\xfd\x8f\xced\x89F.<\x96\xa9op1z\x8b\xc2\\\x1b\x7f\x99U\x00{\xb6\x03M\xac\xc5\n*\xfd{|\xde?\xdb\x0f\x11h\xeb\xd1i%?\x8fsAz\x89\xff\xa4\xb8X\xc9\xf4\xed\xc0T\x02E\x94\xe0g\x8a\x17\x80\xbd\xc5\xc0\xb9%\x85\x18\x8e\x1c\x16\x81\xf1؍\xf8â\xfe|\xa6m\xdb\x1dC\x91{\\\xe5\x9c\x12o\xc6c\xbf\x81>\xd3\xe2)u1\x1b98\xfe\xc3|\xc7~\xb9\n\x80$M\xcc\xc6\xd70\x14\xc9'\xfb\xbb\xea\xea\x83\b\x1c\x10\xfcF\xad\"\x1a\x7f\xb9\n\xc0\x8e\b\xe0*\x941\x9do\xdfw\xbb:A\x04\xf6\x97\xfe\xed\"\xc6~\x85\b\x80$muu\rC\x91lH\x9eMʮ\x8f\xba\xbfE\xf9\xfe\xa4\xec\xfa1.\x98$k\xf5\xa0\xe8쿭C9\x82\xa1HF\xe2Z\xf4د\x1f\xc2D#\xff\xf8\xb7j\x1b\x8c\x148\xf6+\xac\x02\x90\xb6\rE6\xbb\x9c L5\xab:\xd8\xf8;\xfc\x03\a\x95\x7fX\xa8ȱ_\xa1\x02\xb0+\x02\x1c#N\xa9\x8cܬ\xa9\xfbd\xc6\xcb\xcf\xdd\xf9\xf6\xbdJ\x9e\x1d0F\xad\xfa\u0378UY\x01\x90\xf0\x12H3\xfb{+^\xeb\xa3\xea~\xffwջh\xa1[\x0f\xc8\x15&\xc1\x88\xc0\xb0\x01t\xcc\xfb\x97y$\xcf&*u\x94\u0605\xb1\x9f3\x02\xb0W\x04\xf0\x12\xe8\x9fη\uf563\x8ay2S\x8d\x97\x9182\xf6sJ\x00vE\x00C\x91~3\xe7\xa4_\x8d\xbf#\xd8\xfa\xf6\xbd\xd27\x05\xf3\xb4\xf9\xeaO\x97\x1c\x01k\xb1\x1eK\x7f\a\x9f\xf7O\xe5F\x9cx\xa9\x91\v?\x946\xfb\xbb2\xf6s\xae\x02\xd8e\xe2z\x1c\a\x16/\x81#\xb3\xff\xd3\xc9\xd2\x05\xbf$ٕ\xe3\xea.M\x953\xfe\x1d6\xcaqj\x0eS\xbf\x19\xb7p\x15:<\xfb\xfb8\xf6\xeb\xb9\x1f\xf0\xfd\xc9\xd2\xf5\x03\x8cQ\xab>\x1b/ \x00}\x88\x00\xaeB\a\x04H\x05:\xe6\x9d\xc5S\xe5z> t\xdb\x17\xc3ɕ\x1e\xbb\x11\xdf\xc5Pd_\xf6\xffy\xdc\xfb\xb1_\xafUNR\x12\xa1+\xca\xe6\xcb{\x01\x90p\x15z#3~\x7f\xb2:\x95\xceҔ\xff[\x01\xa3\xf6XWw]\xff\x98N\xd7Z\x88\xc06e\x1b\xfbUA\xf0L\xa2ۦ\x19;?\xda6>,\xe6\xca\\t7\x90\xaeV\xb2\xf4/\xe9د\xa7\xed\xf3٧\nO\xfd\xecg\xf6wt\xec\xe7U\x05\xb0K\x95]\x85\xbc;\xed\x97\xf6w\xf7\xb0!hB}\xe4\xcbg\xf5fu\xab\xe8*\xe4\xb2\xcdW>\n\x10xw`\xc8\xc5\xe7\xfdK!\x00R\xf5\\\x85*yZn\x1fɳ\to\x1a\x82VZv}\xec\xe7\xb5\x00\xec\x1a\x8aTA\x04\x92g\x93J~\x1e\x13H\x1d\x7fƂ\xf7|\xca\xfe\x92'M\xc0\xfd\xac7\xa3\x86\xad顬\x1ae\xbd齲\xf9ʁ\x91\v\x8fe&\xd6]\x8e$o\x1a\x7f\xdeV\x00\xbb\x94\xddK\xa0ʍ?_\xab\x00\x97l\xbeJ/\x00e\x16\x01\xbbYSR\xd2C1C\xad\xcb\xcaqw{\x01\x81\xe6]\xb2\xf9\xaa\x84\x00\x94U\x04|\x1d}U\xb9\n0\x81\xbfgW\xbc\xbf\xd3\xca\xe4*T\xf9\xb1\x9f\x87U\x80oc\xbf\xd2\t\xc0\xae\b\x94\xc1U\xa8\xf3\xe8\fQ\xeeS\x15\xe0\xa8\xcdW\xe5\x04@\xda1\x14Q\xb1/Z\x1c\x86*>\xef\xef{\x15\xe0\xaa\xcdW%\x05@\x92\xea\xb3\U000423c6\"\xb6\x1bT\xca\x1d\xb7\x14U\x80Q\xdb\xd7\xc6_i\x05@\xf2\xd3U(\xf9i\x8a\xec\xdfo\x15P\xb0\x89\xa8-\x89}])\xdb\xcd\xf5\x9bq˗c\xc4e\xb7\xf9\xcan\xcb4Q\\\xf27j\x8d\xcf\xc6\xf3\b\x80\xc3\xf8\xe2%@\xe9?\xe0\xba\xfd4Uܸ4,\x8fGE\xa9\aή\x8b\x80]\xa93\xf6\x1bX\x01\x82B\xd6\xce\a\x9b/\x04\xc0\x13\x11\xe8,\x9e\"\x90\x87\xd9\x06,\x8f\xe7\\\xfb\xab\x1d\xd4\xd4*\xd3\x1aV⑳\xf1ٸ隗\x00c\xbf4*\xa8\xe3\xb2\xeb\xc7\xf2\x8b\xff\xa4\\ٿ2\x02 \xb9e(b7k\xec\xfd\xd3\x12Ҽ\x8eL\x97d\xecWY\x01\xd8\x15\x01#-\x14~\xd32\xf6K\xaf\x15\xf0S>\a\xa7j\xc6߇\xcc\x10\x80=\xbc\xec\xear\x91\x86\"v\xb3V\xdaW`\x15\xa3\x00A\xe6O\x06\x1a\xa3\xd6\xe8\xf58F\x00J@ѮB\x94\xfe\x19\xaci\xd6ۀ\xb0\xbc\xd6\xf4\x95@X\xbd\xec\x8f\x00d \x02\x8c\xfd\x1c\xe8\x03\xf4)\x02I\xc5\x1a\x7f\b\xc0\x80\"Ћ\xa1\bo\xf7q\xa1\x0f\xd0\xc76\xc0\xa8\x1d\xd6t\xb7\xaak\x85\x00\xf4\xc1Q\xaeB\xd8|\xb9\xb2\r\xe8\xdd&\xac\x8c6_\xfd`\xb8]\xfagu.\xfa\xcaH\xd1k7]7P盳\b\x80\v7\xf5hG#\xff\xfc\xbf=e\xff\xb1\x1b\xf1\aU^+*\x80\x01x\x9b\xa1\b6_\x0eU\x00\x9b\xb5\x9e\x9a\xb0>\xbeF\x0e\x01p\x80]W\xa1\xdd\x13\x84\xbc\xdd\xc7A\x8e\xd8\x06\x18\xa3V}6^@\x00``\x11\xd8=F\xcc\x13\x7f\xee\x91\x1c5\t\xa8\xe8\xd8\x0f\x01H\x91z3n'\x1b\xb5{\xae\xbc\xae\x1a\xf6l\x03\x0e\xa9\x00\xcan\xf3\x85\x00乀\xc7:Wk\x17~\x90\u0084\xc5pI\x00\xd6\x0e\xa8\x00\x8c\xdac\xdd\xea\x8e\xfd\x10\x80\x14Y\x9b\x8b\xaeʪaF;\x1a\xb9\xf0\x18\x11pI\x00\x0ehȚD\xb7M3^f\x85\x10\x80\xa1XoF\r\x19]{uc\xd574r\xfeG\x16\xc6\x15\xba\xc1\x9b\x0eA%}\xbb\x0f\x02P\x00IM\xb7d\xd5x\xed\xfe\x9aXWxn\x89\xc5q\xa6\x0f\xf0\xfa6\xc0\x84\xfa\x88UA\x00R\xc9\xfe\xc6\xea\xca\xdb\xfe[x\xe2\x05\"\xe0\xe06\xa0J6_\b@\xd6\xd9?\xd4\x17\x87\xfd\xf7\xf0\xc4\v\x85\xa7\x9f\xb3P\x8e\b\x80\x95\x96\x19\xfb!\x00\xa9\xb0r'\xba\xb2\xff1්\xc0\xfb\xcf\x11\x01w*\x80{d\x7f\x04 \x9d\x05\vt\xabןE\x04\nf\xedX%\xde\xee\x83\x00\xe4\xb5\xf7\x9f\x8b\xdeh\xfc!\x02\x0eW\x00ݠ\x926_\xfd\xc0i\xc0^\x83\xbf\x195l\xa8\xef\x06\xfd\xff;\x8b\xa70\n\xc9\xfb\xe6~g\xbd=\xd5\xfa\xaf\x0fX\t*\x80\xa1Ij\xbd\x97\xfeo\xa3vnI\xc1\x89\x17,d\x8e\x84'V\xc8\xfeT\x00\xc5g\xff\xbdl=:û\x02\xf2\xc8lc\x9b\xf3\xef\xfc\xe1?/\xb3\x12T\x00\xc3\xef%kz\x98\xd6\xef\x1a9\xffD\xa6\xbeɢf\x99\xd5F;\xeav&~\xc7J \x00C\xb3r'\xba\xd2o\xe3\xef\xf0\xba4\xd1ȅ\x1f\x10\x81,K\xff\xa9\xf5\xd6\xcc\x1f\xff\xd8f%\x10\x80\xa1K\xff~\xc6~\xfd\x88@\xed\xfc\x13\x99\xd1\x0e\x8b\x9c>\xed\xb0\xb1\xc4\xde\x1f\x01H#P\xf5/\xa9f\xff}ej\xed\xc2\x0f\x88@\xda7\xf4\xf4\x1ag\xfd\xfb\xb9\x0fY\x82\x83\xb3\x7fZ\x8d\xbfC\xfb\v\x9b5u\x1e\x9d\xc1O0\x8d\x9b\x99\xb1\x1f\x15@Z\f;\xf6\xa3\x12(\xe0f\x9e\xd8\xfa\x98U@\x00\x86fu.\xbat\xd0i\xbf\xccD\xe0\xfc\x8f\x18\x8a\fs#\x8fo\xb4&\xff\xed\xbf\x17X\t\x04`\xf8\x804\xfa,\xf7\xbfY\xdf\xc0Uh\b\x01\x1d9\xb9J\xe3\x0f\x01\x18\x9e\xd4\xc7~}\x8a@\r/\x81\xfe\x19\xedܮ\xdf]h\xb3\x10\b\xc0Pd6\xf6\xeb\xe7\x82L\xadb(\xd2\x1f\xedw\xce\xff\x80\xc9'\x020\xff\v\x8d?*\x80\xe1qe\xec\xd7\x0f\xb5\xb3O+\xed%\x10\x9e\xfa?J\x7f*\x80\x94\xb2\xbfcc\xbf\x9e\xe9\x06\xdb\xd6b\xeb\xa3\xd5\xcaV#\xc9\xfc;\xff>\x8f\xcd\x17\x15@\n\xd9\xff\x88\xb7\xfb\xb8\x9d\x06w\\\x85*t\x82Ќv\xd45DZ\xf9B\x00\x86\xa7\u05f7\xfb\xb8.\x02U:F\x8c\xcd\x17\x02\x90ޗv|\xec\xd7OV\xac\x88\b`\xf3\x85\x00\xa4\xb4\xf7\xf7d\xec\x87\b\xfc\x82\x95\xb0\xf9\xca\xea\xfe\xa9T\xf0\xfb\xdc\xf8;*H6k\xda\xfa\xe6\xac\xd4-\x97\xa6\xf3\xbc?\x15@j\xe4e\xf3UT%PFC\x11l\xbe\xa8\x00Ra\xe5\xd3(\n\x12}U\xf6\xefi\u05cfi\xeb\xd1\xe9RT\x02\xc1\xf8F\xeb\x9d\xcf\xff\x82\x00P\x01\xa4\xf0E\xad\xc7c\xbf~\x14\xbd\xbeQ\n/\x013\xdaQwk\x92\xc6\x1f\x02\x90B\xf6/\xd0\xe6\xab\b\xc2\x13/\xfcw\x15\x1a\xed\xdcf\xec\x87\x00\f\x8d\v6_\x85\x89\x80\xbf^\x02\xd8|!\x00)\xed\x89\x03]\xadR\xf6\x7fM\x04<5\x14\t\xde_\xc6\xe6+\xaf\xadVٳ\x7fY\xc7~\xfd\xe0\x93\xb5\x18c?*\x80\xd4(\xf3د\xac\x95\x006_\b@*\xac܉\xae\xe4\xf9v\x1f/D\xc0qW\xa1`|\x83\xe7\xfd\x11\x80\x94\xbeX@\xf6\x7fC\x04\xdcv\x15\xc2\xe6\v\x01H\x87\xb5\xb9\xa8\xb2\x8d\xbf\xa3\xa8\x9d[\x92\x99x\xe9\xde\xde\x1f\x9b\xafbֽl_\xc8G\x9b\xaf\xdcq\xccP\xc4Ԓ\xf6\xd4\x7f\xcc\xd3\xf8\xa3\x02\x18\x1e\x1fm\xbe\xf2\xdf\v\xec\x18\x8a8b-\x16\x9e}J\xe9O\x05\x90R\xf6g\xec\xd73v\xb3\xa6Σ3\xb2\x9bŽ\x1e\x02\x9b/*\x80\x14S\x89\xeesI\xfbP\xff\x82\xbd\x04\xb0\xf9B\x00Rc\xe5Nt\xc5J\x17\xb9\xa4\xfe\x88\x80\x19\xe92\xf6C\x00R\xfa\"\x8c\xfd|\x13\x01\xc6~\b@J{\xff\x92\xd9|\x15&\x02\xe7\x7f\xcc\xcdP\x04\x9b/G\xae\xbb\xf7\xc1ߌ\x1aI\xa8\xaf\x8c4\xcd\xe5L!0s0\x14\xe1\xed>T\x00\xa9\x91\xd4t\x8b\xe0O18\xeb\x1b\x1a9\xffc\xb67\xdd\xd4\x06.?T\x00\xe9d\x7f\xc6~\xd9\xd0}6\xa9\xee\xe2\xa9\xf4o8cZS\x7f\xfa\x13\x02@\x05\x90B\xb9Z\xd3C.a6d\xe1*dF;JFFh\xfc!\x00\xc3S5\x9b\xaf\xc2D \xcdc\xc4\xd8|!\x00\xa9d\xfef4\xcd\xd8/'\x11H\xcfK\x00\x9b/\x04 \x1d^\x86ⴟg\"\x10\x1c\xdf\xc2\xe6\xcbA\xbck\x02\xd2\xf8+\x8eA\xadŰ\xf9\xa2\x02H\rl\xbe\x8a\xad\x04\x061\x14\xc1\xe6\v\x01H\x85չ\xe8\x126_\xc5R;\xb7ԗ\b`\xf3\x85\x00\xa4\xb7_1\xfa\x8cK\xe6\x86\b\xf4\xe4%\x10&\xcb#'W\x19\xfb!\x00\xc3\xc3\xd8\xcf-z1\x141\xf5\xcd{\xf5\xbb\vd\x7f\x97\x93\xaa\x0f\x1f\x12\x9b/G\xe9\x06\xda\xfa\xe6\xecA\x86\"\xed\xe9?\xff\x99\xc6\x1f\x15\xc0\xf0`\xf3\xe5(ar\xe01\xe2Zc\x89ҟ\n \xa5\xec\xcf\xd8\xcfi\xf6[\x8ba\xf3E\x05\x90\xde\xcd\x15\xd2\xf8s>\x8b\xec1\x141a\x82͗G\xd4\\\xfep\xabs\xd1%\x19E\x92\xda\\*\xc7E\xe0XG#\xff\xf0X\x1b\x8b\xa7\xbe\x9c\xf9\xc3<\xd7\v\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\xe4\xff\x01\xf6P(\xf3)+S\x1f\x00\x00\x00\x00IEND\xaeB`\x82"), -} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 9ed40b0be..51eec59a5 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -33,7 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -148,22 +148,24 @@ type serviceClient struct { icError []byte // systray menu items - mStatus *systray.MenuItem - mUp *systray.MenuItem - mDown *systray.MenuItem - mAdminPanel *systray.MenuItem - mSettings *systray.MenuItem - mAbout *systray.MenuItem - mVersionUI *systray.MenuItem - mVersionDaemon *systray.MenuItem - mUpdate *systray.MenuItem - mQuit *systray.MenuItem - mRoutes *systray.MenuItem - mAllowSSH *systray.MenuItem - mAutoConnect *systray.MenuItem - mEnableRosenpass *systray.MenuItem - mNotifications *systray.MenuItem - mAdvancedSettings *systray.MenuItem + mStatus *systray.MenuItem + mUp *systray.MenuItem + mDown *systray.MenuItem + mAdminPanel *systray.MenuItem + mSettings *systray.MenuItem + mAbout *systray.MenuItem + mVersionUI *systray.MenuItem + mVersionDaemon *systray.MenuItem + mUpdate *systray.MenuItem + mQuit *systray.MenuItem + mNetworks *systray.MenuItem + mAllowSSH *systray.MenuItem + mAutoConnect *systray.MenuItem + mEnableRosenpass *systray.MenuItem + mNotifications *systray.MenuItem + mAdvancedSettings *systray.MenuItem + mCreateDebugBundle *systray.MenuItem + mExitNode *systray.MenuItem // application with main windows. app fyne.App @@ -200,6 +202,14 @@ type serviceClient struct { wRoutes fyne.Window eventManager *event.Manager + + exitNodeMu sync.Mutex + mExitNodeItems []menuHandler +} + +type menuHandler struct { + *systray.MenuItem + cancel context.CancelFunc } // newServiceClient instance constructor @@ -473,6 +483,9 @@ func (s *serviceClient) updateStatus() error { status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) + if s.connected { + s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost")) + } s.setDisconnectedStatus() return err } @@ -498,7 +511,8 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetTitle("Connected") s.mUp.Disable() s.mDown.Enable() - s.mRoutes.Enable() + s.mNetworks.Enable() + go s.updateExitNodes() systrayIconState = true } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { s.setDisconnectedStatus() @@ -554,7 +568,9 @@ func (s *serviceClient) setDisconnectedStatus() { s.mStatus.SetTitle("Disconnected") s.mDown.Disable() s.mUp.Enable() - s.mRoutes.Disable() + s.mNetworks.Disable() + s.mExitNode.Disable() + go s.updateExitNodes() } func (s *serviceClient) onTrayReady() { @@ -575,12 +591,18 @@ func (s *serviceClient) onTrayReady() { s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false) s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false) s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false) - s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", "Enable notifications", true) + s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", "Enable notifications", false) s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") + s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", "Create and open debug information bundle") s.loadSettings() - s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window") - s.mRoutes.Disable() + s.exitNodeMu.Lock() + s.mExitNode = systray.AddMenuItem("Exit Node", "Select exit node for routing traffic") + s.mExitNode.Disable() + s.exitNodeMu.Unlock() + + s.mNetworks = systray.AddMenuItem("Networks", "Open the networks management window") + s.mNetworks.Disable() systray.AddSeparator() s.mAbout = systray.AddMenuItem("About", "About") @@ -599,6 +621,9 @@ func (s *serviceClient) onTrayReady() { systray.AddSeparator() s.mQuit = systray.AddMenuItem("Quit", "Quit the client app") + // update exit node menu in case service is already connected + go s.updateExitNodes() + s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() @@ -614,6 +639,12 @@ func (s *serviceClient) onTrayReady() { s.eventManager = event.NewManager(s.app, s.addr) s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked()) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + if event.Category == proto.SystemEvent_SYSTEM { + s.updateExitNodes() + } + }) + go s.eventManager.Start(s.ctx) go func() { @@ -628,7 +659,7 @@ func (s *serviceClient) onTrayReady() { defer s.mUp.Enable() err := s.menuUpClick() if err != nil { - s.runSelfCommand("error-msg", err.Error()) + s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) return } }() @@ -638,7 +669,7 @@ func (s *serviceClient) onTrayReady() { defer s.mDown.Enable() err := s.menuDownClick() if err != nil { - s.runSelfCommand("error-msg", err.Error()) + s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) return } }() @@ -676,6 +707,13 @@ func (s *serviceClient) onTrayReady() { defer s.getSrvConfig() s.runSelfCommand("settings", "true") }() + case <-s.mCreateDebugBundle.ClickedCh: + go func() { + if err := s.createAndOpenDebugBundle(); err != nil { + log.Errorf("Failed to create debug bundle: %v", err) + s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle")) + } + }() case <-s.mQuit.ClickedCh: systray.Quit() return @@ -684,10 +722,10 @@ func (s *serviceClient) onTrayReady() { if err != nil { log.Errorf("%s", err) } - case <-s.mRoutes.ClickedCh: - s.mRoutes.Disable() + case <-s.mNetworks.ClickedCh: + s.mNetworks.Disable() go func() { - defer s.mRoutes.Enable() + defer s.mNetworks.Enable() s.runSelfCommand("networks", "true") }() case <-s.mNotifications.ClickedCh: @@ -718,7 +756,11 @@ func (s *serviceClient) runSelfCommand(command, arg string) { return } - cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg)) + cmd := exec.Command(proc, + fmt.Sprintf("--%s=%s", command, arg), + fmt.Sprintf("--daemon-addr=%s", s.addr), + ) + out, err := cmd.CombinedOutput() if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { log.Errorf("start %s UI: %v, %s", command, err, string(out)) @@ -737,7 +779,12 @@ func normalizedVersion(version string) string { return versionString } -func (s *serviceClient) onTrayExit() {} +// onTrayExit is called when the tray icon is closed. +func (s *serviceClient) onTrayExit() { + for _, item := range s.mExitNodeItems { + item.cancel() + } +} // getSrvClient connection to the service. func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) { @@ -753,7 +800,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService strings.TrimPrefix(s.addr, "tcp://"), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), - grpc.WithUserAgent(system.GetDesktopUIUserAgent()), + grpc.WithUserAgent(desktop.GetUIUserAgent()), ) if err != nil { return nil, fmt.Errorf("dial service: %w", err) diff --git a/client/ui/config/config.go b/client/ui/config/config.go deleted file mode 100644 index fc3361b61..000000000 --- a/client/ui/config/config.go +++ /dev/null @@ -1,46 +0,0 @@ -package config - -import ( - "os" - "runtime" -) - -// ClientConfig basic settings for the UI application. -type ClientConfig struct { - configPath string - logFile string - daemonAddr string -} - -// Config object with default settings. -// -// We are creating this package to extract utility functions from the cmd package -// reading and parsing the configurations for the client should be done here -func Config() *ClientConfig { - defaultConfigPath := "/etc/wiretrustee/config.json" - defaultLogFile := "/var/log/wiretrustee/client.log" - if runtime.GOOS == "windows" { - defaultConfigPath = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" + "config.json" - defaultLogFile = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" + "client.log" - } - - defaultDaemonAddr := "unix:///var/run/wiretrustee.sock" - if runtime.GOOS == "windows" { - defaultDaemonAddr = "tcp://127.0.0.1:41731" - } - return &ClientConfig{ - configPath: defaultConfigPath, - logFile: defaultLogFile, - daemonAddr: defaultDaemonAddr, - } -} - -// DaemonAddr of the gRPC API. -func (c *ClientConfig) DaemonAddr() string { - return c.daemonAddr -} - -// LogFile path. -func (c *ClientConfig) LogFile() string { - return c.logFile -} diff --git a/client/ui/debug.go b/client/ui/debug.go new file mode 100644 index 000000000..845ea284c --- /dev/null +++ b/client/ui/debug.go @@ -0,0 +1,50 @@ +//go:build !(linux && 386) + +package main + +import ( + "fmt" + "path/filepath" + + "fyne.io/fyne/v2" + "github.com/skratchdot/open-golang/open" + + "github.com/netbirdio/netbird/client/proto" + nbstatus "github.com/netbirdio/netbird/client/status" +) + +func (s *serviceClient) createAndOpenDebugBundle() error { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return fmt.Errorf("get client: %v", err) + } + + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) + if err != nil { + return fmt.Errorf("failed to get status: %v", err) + } + + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil) + statusOutput := nbstatus.ParseToFullDetailSummary(overview) + + resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{ + Anonymize: true, + Status: statusOutput, + SystemInfo: true, + }) + if err != nil { + return fmt.Errorf("failed to create debug bundle: %v", err) + } + + bundleDir := filepath.Dir(resp.GetPath()) + if err := open.Start(bundleDir); err != nil { + return fmt.Errorf("failed to open debug bundle directory: %v", err) + } + + s.app.SendNotification(fyne.NewNotification( + "Debug Bundle", + fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()), + )) + + return nil +} diff --git a/client/ui/desktop/desktop.go b/client/ui/desktop/desktop.go new file mode 100644 index 000000000..0c99e2f38 --- /dev/null +++ b/client/ui/desktop/desktop.go @@ -0,0 +1,8 @@ +package desktop + +import "github.com/netbirdio/netbird/version" + +// GetUIUserAgent returns the Desktop ui user agent +func GetUIUserAgent() string { + return "netbird-desktop-ui/" + version.NetbirdVersion() +} diff --git a/client/ui/event/event.go b/client/ui/event/event.go index 7925ee4d3..4d949416d 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -3,6 +3,7 @@ package event import ( "context" "fmt" + "slices" "strings" "sync" "time" @@ -14,17 +15,20 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/client/ui/desktop" ) +type Handler func(*proto.SystemEvent) + type Manager struct { app fyne.App addr string - mu sync.Mutex - ctx context.Context - cancel context.CancelFunc - enabled bool + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc + enabled bool + handlers []Handler } func NewManager(app fyne.App, addr string) *Manager { @@ -100,20 +104,41 @@ func (e *Manager) SetNotificationsEnabled(enabled bool) { func (e *Manager) handleEvent(event *proto.SystemEvent) { e.mu.Lock() enabled := e.enabled + handlers := slices.Clone(e.handlers) e.mu.Unlock() - if !enabled { + // critical events are always shown + if !enabled && event.Severity != proto.SystemEvent_CRITICAL { return } - title := e.getEventTitle(event) - e.app.SendNotification(fyne.NewNotification(title, event.UserMessage)) + if event.UserMessage != "" { + title := e.getEventTitle(event) + body := event.UserMessage + id := event.Metadata["id"] + if id != "" { + body += fmt.Sprintf(" ID: %s", id) + } + e.app.SendNotification(fyne.NewNotification(title, body)) + } + + for _, handler := range handlers { + go handler(event) + } +} + +func (e *Manager) AddHandler(handler Handler) { + e.mu.Lock() + defer e.mu.Unlock() + e.handlers = append(e.handlers, handler) } func (e *Manager) getEventTitle(event *proto.SystemEvent) string { var prefix string switch event.Severity { - case proto.SystemEvent_ERROR, proto.SystemEvent_CRITICAL: + case proto.SystemEvent_CRITICAL: + prefix = "Critical" + case proto.SystemEvent_ERROR: prefix = "Error" case proto.SystemEvent_WARNING: prefix = "Warning" @@ -142,7 +167,7 @@ func getClient(addr string) (proto.DaemonServiceClient, error) { conn, err := grpc.NewClient( strings.TrimPrefix(addr, "tcp://"), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUserAgent(system.GetDesktopUIUserAgent()), + grpc.WithUserAgent(desktop.GetUIUserAgent()), ) if err != nil { return nil, err diff --git a/client/ui/network.go b/client/ui/network.go index 852c4765b..750788cf3 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -3,7 +3,9 @@ package main import ( + "context" "fmt" + "runtime" "sort" "strings" "time" @@ -13,6 +15,7 @@ import ( "fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/layout" "fyne.io/fyne/v2/widget" + "fyne.io/systray" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/proto" @@ -237,14 +240,14 @@ func (s *serviceClient) selectNetwork(id string, checked bool) { s.showError(fmt.Errorf("failed to select network: %v", err)) return } - log.Infof("Route %s selected", id) + log.Infof("Network '%s' selected", id) } else { if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { log.Errorf("failed to deselect network: %v", err) s.showError(fmt.Errorf("failed to deselect network: %v", err)) return } - log.Infof("Network %s deselected", id) + log.Infof("Network '%s' deselected", id) } } @@ -324,6 +327,201 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, s.updateNetworks(grid, f) } +func (s *serviceClient) updateExitNodes() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return + } + + exitNodes, err := s.getExitNodes(conn) + if err != nil { + log.Errorf("get exit nodes: %v", err) + return + } + + s.exitNodeMu.Lock() + defer s.exitNodeMu.Unlock() + + s.recreateExitNodeMenu(exitNodes) + + if len(s.mExitNodeItems) > 0 { + s.mExitNode.Enable() + } else { + s.mExitNode.Disable() + } + + log.Debugf("Exit nodes updated: %d", len(s.mExitNodeItems)) +} + +func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { + for _, node := range s.mExitNodeItems { + node.cancel() + node.Remove() + } + s.mExitNodeItems = nil + + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" { + s.mExitNode.Remove() + s.mExitNode = systray.AddMenuItem("Exit Node", "Select exit node for routing traffic") + } + + for _, node := range exitNodes { + menuItem := s.mExitNode.AddSubMenuItemCheckbox( + node.ID, + fmt.Sprintf("Use exit node %s", node.ID), + node.Selected, + ) + + ctx, cancel := context.WithCancel(context.Background()) + s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{ + MenuItem: menuItem, + cancel: cancel, + }) + go s.handleChecked(ctx, node.ID, menuItem) + } + +} + +func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) { + ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout) + defer cancel() + + resp, err := conn.ListNetworks(ctx, &proto.ListNetworksRequest{}) + if err != nil { + return nil, fmt.Errorf("list networks: %v", err) + } + + var exitNodes []*proto.Network + for _, network := range resp.Routes { + if network.Range == "0.0.0.0/0" { + exitNodes = append(exitNodes, network) + } + } + return exitNodes, nil +} + +func (s *serviceClient) handleChecked(ctx context.Context, id string, item *systray.MenuItem) { + for { + select { + case <-ctx.Done(): + return + case _, ok := <-item.ClickedCh: + if !ok { + return + } + if err := s.toggleExitNode(id, item); err != nil { + log.Errorf("failed to toggle exit node: %v", err) + continue + } + } + } +} + +// Add function to toggle exit node selection +func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) error { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return fmt.Errorf("get client: %v", err) + } + + log.Infof("Toggling exit node '%s'", nodeID) + + s.exitNodeMu.Lock() + defer s.exitNodeMu.Unlock() + + exitNodes, err := s.getExitNodes(conn) + if err != nil { + return fmt.Errorf("get exit nodes: %v", err) + } + + var exitNode *proto.Network + // find other selected nodes and ours + ids := make([]string, 0, len(exitNodes)) + for _, node := range exitNodes { + if node.ID == nodeID { + // preserve original state + cp := *node //nolint:govet + exitNode = &cp + + // set desired state for recreation + node.Selected = true + continue + } + if node.Selected { + ids = append(ids, node.ID) + + // set desired state for recreation + node.Selected = false + } + } + + if item.Checked() && len(ids) == 0 { + // exit node is the only selected node, deselect it + ids = append(ids, nodeID) + exitNode = nil + } + + // deselect all other selected exit nodes + if err := s.deselectOtherExitNodes(conn, ids, item); err != nil { + return err + } + + if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil { + return err + } + + // linux/bsd doesn't handle Check/Uncheck well, so we recreate the menu + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" { + s.recreateExitNodeMenu(exitNodes) + } + + return nil +} + +func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string, currentItem *systray.MenuItem) error { + // deselect all other selected exit nodes + if len(ids) > 0 { + deselectReq := &proto.SelectNetworksRequest{ + NetworkIDs: ids, + } + if _, err := conn.DeselectNetworks(s.ctx, deselectReq); err != nil { + return fmt.Errorf("deselect networks: %v", err) + } + + log.Infof("Deselected exit nodes: %v", ids) + } + + // uncheck all other exit node menu items + for _, i := range s.mExitNodeItems { + if i.MenuItem == currentItem { + continue + } + i.Uncheck() + log.Infof("Unchecked exit node %v", i) + } + + return nil +} + +func (s *serviceClient) selectNewExitNode(conn proto.DaemonServiceClient, exitNode *proto.Network, nodeID string, item *systray.MenuItem) error { + if exitNode != nil && !exitNode.Selected { + selectReq := &proto.SelectNetworksRequest{ + NetworkIDs: []string{exitNode.ID}, + Append: true, + } + if _, err := conn.SelectNetworks(s.ctx, selectReq); err != nil { + return fmt.Errorf("select network: %v", err) + } + + log.Infof("Selected exit node '%s'", nodeID) + } + + item.Check() + + return nil +} + func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { switch tabs.Selected().Text { case overlappingNetworksText: diff --git a/go.sum b/go.sum index c0685caa9..0ccf91b5d 100644 --- a/go.sum +++ b/go.sum @@ -18,10 +18,10 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= -cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs= -cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w= -cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= +cloud.google.com/go/auth v0.14.1 h1:AwoJbzUdxA/whv1qj3TLKwh3XX5sikny2fc40wUl+h0= +cloud.google.com/go/auth v0.14.1/go.mod h1:4JHUxlGXisL0AW8kXPtUF6ztuOksyfUQNFjfsOCXkPM= +cloud.google.com/go/auth/oauth2adapt v0.2.7 h1:/Lc7xODdqcEw8IrZ9SvwnlLX6j9FHQM74z6cBk9Rw6M= +cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= @@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= @@ -225,8 +225,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= @@ -263,14 +263,12 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= -github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang/glog v1.2.3 h1:oDTdz9f5VGVVNGu/Q7UXKWYsD0873HXLHdJUNBsSEKM= +github.com/golang/glog v1.2.3/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -345,18 +343,18 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= -github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= @@ -617,8 +615,8 @@ github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KW github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -683,11 +681,11 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U= github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI= @@ -739,28 +737,28 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0/go.mod h1:HDBUsEjOuRC0EzKZ1bSaRGZWUBAzo+MhAcUUORSr4D0= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q= +go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= +go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= +go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= +go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= +go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= +go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU= +go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ= +go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= +go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -885,8 +883,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -900,8 +898,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg= -golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8= +golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= +golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1019,8 +1017,8 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1114,8 +1112,8 @@ google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjR google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= google.golang.org/api v0.44.0/go.mod h1:EBOGZqzyhtvMDoxwS97ctnh0zUmYY6CxqXsc1AvkYD8= -google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk= -google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw= +google.golang.org/api v0.220.0 h1:3oMI4gdBgB72WFVwE1nerDD8W3HUOS4kypK6rRLbGns= +google.golang.org/api v0.220.0/go.mod h1:26ZAlY6aN/8WgpCzjPNy18QpYaz7Zgg1h0qe1GkZEmY= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1164,10 +1162,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287 h1:J1H9f+LEdWAfHcez/4cvaVBox7cOYT+IU6rgqj5x++8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287/go.mod h1:8BS3B93F/U1juMFq9+EDk+qOT5CO1R9IzXxG3PTqiRk= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1188,8 +1187,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= +google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1204,8 +1203,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM= +google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/management/client/client.go b/management/client/client.go index e9eeaccc1..950f6137e 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -15,7 +15,7 @@ type Client interface { io.Closer Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKey() (*wgtypes.Key, error) - Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) + Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) diff --git a/management/client/client_test.go b/management/client/client_test.go index 2e9ce8b8b..73427b38a 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -79,7 +79,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -206,7 +206,7 @@ func TestClient_LoginRegistered(t *testing.T) { t.Error(err) } info := system.GetInfo(context.TODO()) - resp, err := client.Register(*key, ValidKey, "", info, nil) + resp, err := client.Register(*key, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -236,7 +236,7 @@ func TestClient_Sync(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = client.Register(*serverKey, ValidKey, "", info, nil) + _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -252,7 +252,7 @@ func TestClient_Sync(t *testing.T) { } info = system.GetInfo(context.TODO()) - _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil) + _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) if err != nil { t.Fatal(err) } @@ -353,7 +353,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = testClient.Register(*key, ValidKey, "", info, nil) + _, err = testClient.Register(*key, ValidKey, "", info, nil, nil) if err != nil { t.Errorf("error while trying to register client: %v", err) } diff --git a/management/client/grpc.go b/management/client/grpc.go index d02509c27..d3aaffec0 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -365,12 +365,12 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) { +func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys}) + return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. diff --git a/management/client/mock.go b/management/client/mock.go index 11564093a..9e1786f82 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -14,7 +14,7 @@ type MockClient struct { CloseFunc func() error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKeyFunc func() (*wgtypes.Key, error) - RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) + RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) @@ -46,11 +46,11 @@ func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { return m.GetServerPublicKeyFunc() } -func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) { +func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.RegisterFunc == nil { return nil, nil } - return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey) + return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) } func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { diff --git a/management/cmd/management.go b/management/cmd/management.go index 469be4d67..64ca958b5 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -41,13 +41,12 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/auth" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" @@ -264,11 +263,11 @@ var ( tlsEnabled = true } - jwtValidator, err := jwtclaims.NewJWTValidator( - ctx, + authManager := auth.NewManager(store, config.HttpConfig.AuthIssuer, - config.GetAuthAudiences(), + config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, +<<<<<<< HEAD config.HttpConfig.IdpSignKeyRefreshEnabled, ) if err != nil { @@ -282,12 +281,24 @@ var ( KeysLocation: config.HttpConfig.AuthKeysLocation, } +======= + config.HttpConfig.AuthUserIDClaim, + config.GetAuthAudiences(), + config.HttpConfig.IdpSignKeyRefreshEnabled) + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) +>>>>>>> main groupsManager := groups.NewManager(store, permissionsManager, accountManager) resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) +<<<<<<< HEAD httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator, proxyController, permissionsManager, peersManager) +======= + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, config, integratedPeerValidator) +>>>>>>> main if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -296,7 +307,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index c62acf6df..8cb5bec07 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2,11 +2,8 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" "errors" "fmt" - "hash/crc32" "math/rand" "net" "net/netip" @@ -24,15 +21,19 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" +<<<<<<< HEAD "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/jwtclaims" +======= + "github.com/netbirdio/netbird/management/server/integrated_validator" +>>>>>>> main nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -78,13 +79,10 @@ type AccountManager interface { GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) - GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error - MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) - GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -151,6 +149,7 @@ type AccountManager interface { DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) + SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error } type DefaultAccountManager struct { @@ -256,6 +255,11 @@ func BuildManager( metrics telemetry.AppMetrics, proxyController port_forwarding.Controller, ) (*DefaultAccountManager, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) + }() + am := &DefaultAccountManager{ Store: store, geo: geo, @@ -274,39 +278,21 @@ func BuildManager( requestBuffer: NewAccountRequestBuffer(ctx, store), proxyController: proxyController, } - allAccounts := store.GetAllAccounts(ctx) + accountsCounter, err := store.GetAccountsCounter(ctx) + if err != nil { + log.WithContext(ctx).Error(err) + } + // enable single account mode only if configured by user and number of existing accounts is not grater than 1 - am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 + am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", accountsCounter) } else { - log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) - } - - // if account doesn't have a default group - // we create 'all' group and add all peers into it - // also we create default rule with source as destination - for _, account := range allAccounts { - shouldSave := false - - _, err := account.GetGroupAll() - if err != nil { - if err := addAllGroup(account); err != nil { - return nil, err - } - shouldSave = true - } - - if shouldSave { - err = store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) @@ -959,11 +945,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, primaryDomain bool, ) error { - if claims.Domain == "" { - log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + if userAuth.Domain == "" { + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth) return nil } @@ -976,11 +962,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return err } - if domainIsUpToDate(accountDomain, domainCategory, claims) { + if domainIsUpToDate(accountDomain, domainCategory, userAuth) { return nil } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -989,13 +975,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx newDomain := accountDomain newCategoty := domainCategory - lowerDomain := strings.ToLower(claims.Domain) + lowerDomain := strings.ToLower(userAuth.Domain) if accountDomain != lowerDomain && user.HasAdminPower() { newDomain = lowerDomain } if accountDomain == lowerDomain { - newCategoty = claims.DomainCategory + newCategoty = userAuth.DomainCategory } return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) @@ -1011,16 +997,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - claims jwtclaims.AuthorizationClaims, + userAuth nbcontext.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID - err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID) if err != nil { return err } @@ -1030,20 +1016,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { - if claims.UserId == "" { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { + if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } - lowerDomain := strings.ToLower(claims.Domain) + lowerDomain := strings.ToLower(userAuth.Domain) - newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain) if err != nil { return "", err } newAccount.Domain = lowerDomain - newAccount.DomainCategory = claims.DomainCategory + newAccount.DomainCategory = userAuth.DomainCategory newAccount.IsDomainPrimaryAccount = true err = am.Store.SaveAccount(ctx, newAccount) @@ -1051,33 +1037,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil) return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - newUser := types.NewRegularUser(claims.UserId) + newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) if err != nil { return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) return domainAccountID, nil } @@ -1117,76 +1103,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str return nil } -// MarkPATUsed marks a personal access token as used -func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) -} - // GetAccount returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } -// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. -func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { - user, pat, err = am.extractPATFromToken(ctx, token) - if err != nil { - return nil, nil, "", "", err - } - - domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) - if err != nil { - return nil, nil, "", "", err - } - - return user, pat, domain, category, nil -} - -// extractPATFromToken validates the token structure and retrieves associated User and PAT. -func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { - if len(token) != types.PATLength { - return nil, nil, fmt.Errorf("token has incorrect length") - } - - prefix := token[:len(types.PATPrefix)] - if prefix != types.PATPrefix { - return nil, nil, fmt.Errorf("token has wrong prefix") - } - secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] - encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] - - verificationChecksum, err := base62.Decode(encodedChecksum) - if err != nil { - return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) - } - - secretChecksum := crc32.ChecksumIEEE([]byte(secret)) - if secretChecksum != verificationChecksum { - return nil, nil, fmt.Errorf("token checksum does not match") - } - - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - - var user *types.User - var pat *types.PersonalAccessToken - - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) - if err != nil { - return err - } - - user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) - return err - }) - if err != nil { - return nil, nil, err - } - - return user, pat, nil -} - // GetAccountByID returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) @@ -1201,58 +1122,56 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return am.Store.GetAccount(ctx, accountID) } -// GetAccountIDFromToken returns an account ID associated with this token. -func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - if claims.UserId == "" { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. - claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = types.PrivateCategory + userAuth.Domain = am.singleAccountModeDomain + userAuth.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) if err != nil { return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { // this is not really possible because we got an account by user ID - return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) + } + + if userAuth.IsChild { + return accountID, user.Id, nil } if user.AccountID != accountID { - return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID) } - if !user.IsServiceUser && claims.Invited { + if !user.IsServiceUser && userAuth.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { return "", "", err } } - if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { - return "", "", err - } - return accountID, user.Id, nil } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { - if isToken, ok := claim.(bool); ok && isToken { - return nil - } +// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { + if userAuth.IsChild || userAuth.IsPAT { + return nil } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return err } @@ -1266,9 +1185,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId) defer func() { if unlockAccount != nil { unlockAccount() @@ -1280,17 +1197,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var hasChanges bool var user *types.User err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } - changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups) if err != nil { return fmt.Errorf("error getting JWT groups changes: %w", err) } @@ -1315,7 +1232,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -1325,7 +1242,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -1339,7 +1256,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -1357,45 +1274,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta) } } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta) } } if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) if err != nil { return err } - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) if err != nil { return err } if removedGroupAffectsPeers || newGroupsAffectsPeers { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.UpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) + am.UpdateAccountPeers(ctx, userAuth.AccountId) } } @@ -1420,24 +1337,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { +// +// UserAuth IsChild -> checks that account exists +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", - claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) - if claims.UserId == "" { + if userAuth.UserId == "" { return "", errors.New(emptyUserID) } - if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) + if userAuth.IsChild { + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) + if err != nil || !exists { + return "", err + } + return userAuth.AccountId, nil } - if claims.AccountId != "" { - return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) { + return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) + } + + if userAuth.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth) } // We checked if the domain has a primary account already - domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain) if cancel != nil { defer cancel() } @@ -1445,14 +1372,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } if userAccountID != "" { - if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil { return "", err } @@ -1460,10 +1387,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context } if domainAccountID != "" { - return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth) } - return am.addNewPrivateAccount(ctx, domainAccountID, claims) + return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) @@ -1491,40 +1418,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + if userAccountID != userAuth.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err } - if domainIsUpToDate(accountDomain, domainCategory, claims) { - return claims.AccountId, nil + if domainIsUpToDate(accountDomain, domainCategory, userAuth) { + return userAuth.AccountId, nil } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err } - err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth) if err != nil { return "", err } - return claims.AccountId, nil + return userAuth.AccountId, nil } func handleNotFound(err error) error { @@ -1539,8 +1466,8 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain +func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { + return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -1622,34 +1549,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string { return am.dnsDomain } -// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT -// group propagation and set the list of groups with access permissions. -func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - accountID, _, err := am.GetAccountIDFromToken(ctx, claims) - if err != nil { - return err - } - - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return err - } - - // Ensures JWT group synchronization to the management is enabled before, - // filtering access based on the allowed groups. - if settings != nil && settings.JWTGroupsEnabled { - if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - - if !userHasAllowedGroup(allowedGroups, userJWTGroups) { - return fmt.Errorf("user does not belong to any of the allowed JWT groups") - } - } - } - - return nil -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) am.UpdateAccountPeers(ctx, accountID) @@ -1717,46 +1616,6 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) } -// addAllGroup to account object if it doesn't exist -func addAllGroup(account *types.Account) error { - if len(account.Groups) == 0 { - allGroup := &types.Group{ - ID: xid.New().String(), - Name: "All", - Issued: types.GroupIssuedAPI, - } - for _, peer := range account.Peers { - allGroup.Peers = append(allGroup.Peers, peer.ID) - } - account.Groups = map[string]*types.Group{allGroup.ID: allGroup} - - id := xid.New().String() - - defaultPolicy := &types.Policy{ - ID: id, - Name: types.DefaultRuleName, - Description: types.DefaultRuleDescription, - Enabled: true, - Rules: []*types.PolicyRule{ - { - ID: id, - Name: types.DefaultRuleName, - Description: types.DefaultRuleDescription, - Enabled: true, - Sources: []string{allGroup.ID}, - Destinations: []string{allGroup.ID}, - Bidirectional: true, - Protocol: types.PolicyRuleProtocolALL, - Action: types.PolicyTrafficActionAccept, - }, - }, - } - - account.Policies = []*types.Policy{defaultPolicy} - } - return nil -} - // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") @@ -1801,45 +1660,12 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty }, } - if err := addAllGroup(acc); err != nil { + if err := acc.AddAllGroup(); err != nil { log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } return acc } -// extractJWTGroups extracts the group names from a JWT token's claims. -func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[claimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } else { - log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) - } - } - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) - } - - return userJWTGroups -} - -// userHasAllowedGroup checks if a user belongs to any of the allowed groups. -func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { - for _, userGroup := range userGroups { - for _, allowedGroup := range allowedGroups { - if userGroup == allowedGroup { - return true - } - } - } - return false -} - // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. diff --git a/management/server/account_test.go b/management/server/account_test.go index 802cc8bd2..8a042f4c3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,8 +2,6 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" "encoding/json" "fmt" "io" @@ -15,9 +13,12 @@ import ( "testing" "time" +<<<<<<< HEAD "github.com/golang-jwt/jwt" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" +======= +>>>>>>> main "github.com/netbirdio/netbird/management/server/util" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -31,7 +32,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -438,7 +439,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams jwtclaims.AuthorizationClaims + type initUserParams nbcontext.UserAuth var ( publicDomain = "public.com" @@ -461,7 +462,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims jwtclaims.AuthorizationClaims + inputClaims nbcontext.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -476,7 +477,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -493,7 +494,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -510,7 +511,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -527,7 +528,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -545,7 +546,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -562,7 +563,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -580,7 +581,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -609,7 +610,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -617,7 +618,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromUserAuth(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -636,14 +637,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } } -func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { +func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization @@ -651,65 +650,50 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - - claims := jwtclaims.AuthorizationClaims{ + claims := nbcontext.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, DomainCategory: "test-category", - Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}, + Groups: []string{"group1", "group2"}, } - t.Run("JWT groups disabled", func(t *testing.T) { - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err := manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 1, "only ALL group should exists") }) - t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err = manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) - t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err = manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } - g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") - g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") @@ -717,88 +701,6 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { }) } -func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - - token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &types.User{ - Id: "someUser", - PATs: map[string]*types.PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - UserID: "someUser", - HashedToken: encodedHashedToken, - }, - }, - } - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - am := DefaultAccountManager{ - Store: store, - } - - user, pat, _, _, err := am.GetPATInfo(context.Background(), token) - if err != nil { - t.Fatalf("Error when getting Account from PAT: %s", err) - } - - assert.Equal(t, "account_id", user.AccountID) - assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) -} - -func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - - token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &types.User{ - Id: "someUser", - PATs: map[string]*types.PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - }, - }, - } - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - am := DefaultAccountManager{ - Store: store, - } - - err = am.MarkPATUsed(context.Background(), "tokenId") - if err != nil { - t.Fatalf("Error when marking PAT used: %s", err) - } - - account, err = am.Store.GetAccount(context.Background(), "account_id") - if err != nil { - t.Fatalf("Error when getting account: %s", err) - } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) -} - func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -963,13 +865,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := jwtclaims.AuthorizationClaims{ + claims := nbcontext.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := jwtclaims.AuthorizationClaims{ + publicClaims := nbcontext.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, @@ -2684,11 +2586,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group3"}, + IsPAT: true, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2697,11 +2601,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{}, } - err := manager.syncJWTGroups(context.Background(), "accountID", claims) + err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2710,11 +2615,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1"}, } - err := manager.syncJWTGroups(context.Background(), "accountID", claims) + err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2730,11 +2636,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2747,11 +2654,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1", "group2"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2760,11 +2668,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group2"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2773,11 +2682,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user2", - Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + claims := nbcontext.UserAuth{ + UserId: "user2", + AccountId: "accountID", + Groups: []string{"group1", "group3"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") @@ -2790,11 +2700,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2804,11 +2715,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user2", - Raw: jwt.MapClaims{}, + claims := nbcontext.UserAuth{ + UserId: "user2", + AccountId: "accountID", + Groups: []string{}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") diff --git a/management/server/jwtclaims/extractor.go b/management/server/auth/jwt/extractor.go similarity index 51% rename from management/server/jwtclaims/extractor.go rename to management/server/auth/jwt/extractor.go index 18214b434..fab429125 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/auth/jwt/extractor.go @@ -1,15 +1,17 @@ -package jwtclaims +package jwt import ( - "net/http" + "errors" + "net/url" "time" "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + nbcontext "github.com/netbirdio/netbird/management/server/context" ) const ( - // TokenUserProperty key for the user property in the request context - TokenUserProperty = "user" // AccountIDSuffix suffix for the account id claim AccountIDSuffix = "wt_account_id" // DomainIDSuffix suffix for the domain id claim @@ -22,19 +24,16 @@ const ( LastLoginSuffix = "nb_last_login" // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation Invited = "nb_invited" - // IsToken claim indicates that auth type from the user is a token - IsToken = "is_token" ) -// ExtractClaims Extract function type -type ExtractClaims func(r *http.Request) AuthorizationClaims +var ( + errUserIDClaimEmpty = errors.New("user ID claim token value is empty") +) // ClaimsExtractor struct that holds the extract function type ClaimsExtractor struct { authAudience string userIDClaim string - - FromRequestContext ExtractClaims } // ClaimsExtractorOption is a function that configures the ClaimsExtractor @@ -54,13 +53,6 @@ func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption { } } -// WithFromRequestContext sets the function that extracts claims from the request context -func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption { - return func(c *ClaimsExtractor) { - c.FromRequestContext = ec - } -} - // NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature, // then it will use that logic. Uses ExtractClaimsFromRequestContext by default func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { @@ -68,49 +60,13 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { for _, option := range options { option(ce) } - if ce.FromRequestContext == nil { - ce.FromRequestContext = ce.fromRequestContext - } + if ce.userIDClaim == "" { ce.userIDClaim = UserIDClaim } return ce } -// FromToken extracts claims from the token (after auth) -func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { - claims := token.Claims.(jwt.MapClaims) - jwtClaims := AuthorizationClaims{ - Raw: claims, - } - userID, ok := claims[c.userIDClaim].(string) - if !ok { - return jwtClaims - } - jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] - if ok { - jwtClaims.AccountId = accountIDClaim.(string) - } - domainClaim, ok := claims[c.authAudience+DomainIDSuffix] - if ok { - jwtClaims.Domain = domainClaim.(string) - } - domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] - if ok { - jwtClaims.DomainCategory = domainCategoryClaim.(string) - } - LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix] - if ok { - jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) - } - invitedBool, ok := claims[c.authAudience+Invited] - if ok { - jwtClaims.Invited = invitedBool.(bool) - } - return jwtClaims -} - func parseTime(timeString string) time.Time { if timeString == "" { return time.Time{} @@ -122,11 +78,67 @@ func parseTime(timeString string) time.Time { return parsedTime } -// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) -func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims { - if r.Context().Value(TokenUserProperty) == nil { - return AuthorizationClaims{} +func (c ClaimsExtractor) audienceClaim(claimName string) string { + url, err := url.JoinPath(c.authAudience, claimName) + if err != nil { + return c.authAudience + claimName // as it was previously } - token := r.Context().Value(TokenUserProperty).(*jwt.Token) - return c.FromToken(token) + + return url +} + +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { + claims := token.Claims.(jwt.MapClaims) + userAuth := nbcontext.UserAuth{} + + userID, ok := claims[c.userIDClaim].(string) + if !ok { + return userAuth, errUserIDClaimEmpty + } + userAuth.UserId = userID + + if accountIDClaim, ok := claims[c.audienceClaim(AccountIDSuffix)]; ok { + userAuth.AccountId = accountIDClaim.(string) + } + + if domainClaim, ok := claims[c.audienceClaim(DomainIDSuffix)]; ok { + userAuth.Domain = domainClaim.(string) + } + + if domainCategoryClaim, ok := claims[c.audienceClaim(DomainCategorySuffix)]; ok { + userAuth.DomainCategory = domainCategoryClaim.(string) + } + + if lastLoginClaimString, ok := claims[c.audienceClaim(LastLoginSuffix)]; ok { + userAuth.LastLogin = parseTime(lastLoginClaimString.(string)) + } + + if invitedBool, ok := claims[c.audienceClaim(Invited)]; ok { + if value, ok := invitedBool.(bool); ok { + userAuth.Invited = value + } + } + + return userAuth, nil +} + +func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { + claims := token.Claims.(jwt.MapClaims) + userJWTGroups := make([]string, 0) + + if claim, ok := claims[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } else { + log.Debugf("JWT claim %q is not a string array", claimName) + } + + return userJWTGroups } diff --git a/management/server/auth/jwt/validator.go b/management/server/auth/jwt/validator.go new file mode 100644 index 000000000..5b38ca786 --- /dev/null +++ b/management/server/auth/jwt/validator.go @@ -0,0 +1,302 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt" + + log "github.com/sirupsen/logrus" +) + +// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation +type Jwks struct { + Keys []JSONWebKey `json:"keys"` + expiresInTime time.Time +} + +// The supported elliptic curves types +const ( + // p256 represents a cryptographic elliptical curve type. + p256 = "P-256" + + // p384 represents a cryptographic elliptical curve type. + p384 = "P-384" + + // p521 represents a cryptographic elliptical curve type. + p521 = "P-521" +) + +// JSONWebKey is a representation of a Jason Web Key +type JSONWebKey struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` + X5c []string `json:"x5c"` +} + +type Validator struct { + lock sync.Mutex + issuer string + audienceList []string + keysLocation string + idpSignkeyRefreshEnabled bool + keys *Jwks +} + +var ( + errKeyNotFound = errors.New("unable to find appropriate key") + errInvalidAudience = errors.New("invalid audience") + errInvalidIssuer = errors.New("invalid issuer") + errTokenEmpty = errors.New("required authorization token not found") + errTokenInvalid = errors.New("token is invalid") + errTokenParsing = errors.New("token could not be parsed") +) + +func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator { + keys, err := getPemKeys(keysLocation) + if err != nil { + log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err) + } + + return &Validator{ + keys: keys, + issuer: issuer, + audienceList: audienceList, + keysLocation: keysLocation, + idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled, + } +} + +func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { + return func(token *jwt.Token) (interface{}, error) { + // Verify 'aud' claim + var checkAud bool + for _, audience := range v.audienceList { + checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + if checkAud { + break + } + } + if !checkAud { + return token, errInvalidAudience + } + + // Verify 'issuer' claim + checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false) + if !checkIss { + return token, errInvalidIssuer + } + + // If keys are rotated, verify the keys prior to token validation + if v.idpSignkeyRefreshEnabled { + // If the keys are invalid, retrieve new ones + // @todo propose a separate go routine to regularly check these to prevent blocking when actually + // validating the token + if !v.keys.stillValid() { + v.lock.Lock() + defer v.lock.Unlock() + + refreshedKeys, err := getPemKeys(v.keysLocation) + if err != nil { + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + refreshedKeys = v.keys + } + + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + + v.keys = refreshedKeys + } + } + + publicKey, err := getPublicKey(token, v.keys) + if err == nil { + return publicKey, nil + } + + msg := fmt.Sprintf("getPublicKey error: %s", err) + if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled { + msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) + } + + log.WithContext(ctx).Error(msg) + + return nil, err + } +} + +// ValidateAndParse validates the token and returns the parsed token +func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + // If the token is empty... + if token == "" { + // If we get here, the required token is missing + log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") + return nil, errTokenEmpty + } + + // Now parse the token + parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx)) + + // Check if there was an error in parsing... + if err != nil { + err = fmt.Errorf("%w: %s", errTokenParsing, err) + log.WithContext(ctx).Error(err.Error()) + return nil, err + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + log.WithContext(ctx).Debug(errTokenInvalid.Error()) + return nil, errTokenInvalid + } + + return parsedToken, nil +} + +// stillValid returns true if the JSONWebKey still valid and have enough time to be used +func (jwks *Jwks) stillValid() bool { + return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) +} + +func getPemKeys(keysLocation string) (*Jwks, error) { + jwks := &Jwks{} + + url, err := url.ParseRequestURI(keysLocation) + if err != nil { + return jwks, err + } + + resp, err := http.Get(url.String()) + if err != nil { + return jwks, err + } + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(jwks) + if err != nil { + return jwks, err + } + + cacheControlHeader := resp.Header.Get("Cache-Control") + expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) + jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + + return jwks, nil +} + +func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) { + // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time + for k := range jwks.Keys { + if token.Header["kid"] != jwks.Keys[k].Kid { + continue + } + + if len(jwks.Keys[k].X5c) != 0 { + cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) + } + + if jwks.Keys[k].Kty == "RSA" { + return getPublicKeyFromRSA(jwks.Keys[k]) + } + if jwks.Keys[k].Kty == "EC" { + return getPublicKeyFromECDSA(jwks.Keys[k]) + } + } + + return nil, errKeyNotFound +} + +func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { + if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { + return nil, fmt.Errorf("ecdsa key incomplete") + } + + var xCoordinate []byte + if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { + return nil, err + } + + var yCoordinate []byte + if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { + return nil, err + } + + publicKey = &ecdsa.PublicKey{} + + var curve elliptic.Curve + switch jwk.Crv { + case p256: + curve = elliptic.P256() + case p384: + curve = elliptic.P384() + case p521: + curve = elliptic.P521() + } + + publicKey.Curve = curve + publicKey.X = big.NewInt(0).SetBytes(xCoordinate) + publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) + + return publicKey, nil +} + +func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { + decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, err + } + decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, err + } + + var n, e big.Int + e.SetBytes(decodedE) + n.SetBytes(decodedN) + + return &rsa.PublicKey{ + E: int(e.Int64()), + N: &n, + }, nil +} + +// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header +func getMaxAgeFromCacheHeader(cacheControl string) int { + // Split into individual directives + directives := strings.Split(cacheControl, ",") + + for _, directive := range directives { + directive = strings.TrimSpace(directive) + if strings.HasPrefix(directive, "max-age=") { + // Extract the max-age value + maxAgeStr := strings.TrimPrefix(directive, "max-age=") + maxAge, err := strconv.Atoi(maxAgeStr) + if err != nil { + return 0 + } + + return maxAge + } + } + + return 0 +} diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go new file mode 100644 index 000000000..6835a3ced --- /dev/null +++ b/management/server/auth/manager.go @@ -0,0 +1,170 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "hash/crc32" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/base62" + nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +var _ Manager = (*manager)(nil) + +type Manager interface { + ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + MarkPATUsed(ctx context.Context, tokenID string) error + GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) +} + +type manager struct { + store store.Store + + validator *nbjwt.Validator + extractor *nbjwt.ClaimsExtractor +} + +func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager { + // @note if invalid/missing parameters are sent the validator will instantiate + // but it will fail when validating and parsing the token + jwtValidator := nbjwt.NewValidator( + issuer, + allAudiences, + keysLocation, + idpRefreshKeys, + ) + + claimsExtractor := nbjwt.NewClaimsExtractor( + nbjwt.WithAudience(audience), + nbjwt.WithUserIDClaim(userIdClaim), + ) + + return &manager{ + store: store, + + validator: jwtValidator, + extractor: claimsExtractor, + } +} + +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { + token, err := m.validator.ValidateAndParse(ctx, value) + if err != nil { + return nbcontext.UserAuth{}, nil, err + } + + userAuth, err := m.extractor.ToUserAuth(token) + if err != nil { + return nbcontext.UserAuth{}, nil, err + } + return userAuth, token, err +} + +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if userAuth.IsChild || userAuth.IsPAT { + return userAuth, nil + } + + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) + if err != nil { + return userAuth, err + } + + // Ensures JWT group synchronization to the management is enabled before, + // filtering access based on the allowed groups. + if settings != nil && settings.JWTGroupsEnabled { + userAuth.Groups = m.extractor.ToGroups(token, settings.JWTGroupsClaimName) + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + if !userHasAllowedGroup(allowedGroups, userAuth.Groups) { + return userAuth, fmt.Errorf("user does not belong to any of the allowed JWT groups") + } + } + } + + return userAuth, nil +} + +// MarkPATUsed marks a personal access token as used +func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { + return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) +} + +// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. +func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { + user, pat, err = am.extractPATFromToken(ctx, token) + if err != nil { + return nil, nil, "", "", err + } + + domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) + if err != nil { + return nil, nil, "", "", err + } + + return user, pat, domain, category, nil +} + +// extractPATFromToken validates the token structure and retrieves associated User and PAT. +func (am *manager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { + return nil, nil, fmt.Errorf("PAT has incorrect length") + } + + prefix := token[:len(types.PATPrefix)] + if prefix != types.PATPrefix { + return nil, nil, fmt.Errorf("PAT has wrong prefix") + } + secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] + encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] + + verificationChecksum, err := base62.Decode(encodedChecksum) + if err != nil { + return nil, nil, fmt.Errorf("PAT checksum decoding failed: %w", err) + } + + secretChecksum := crc32.ChecksumIEEE([]byte(secret)) + if secretChecksum != verificationChecksum { + return nil, nil, fmt.Errorf("PAT checksum does not match") + } + + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + + var user *types.User + var pat *types.PersonalAccessToken + + err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) + if err != nil { + return err + } + + user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) + return err + }) + if err != nil { + return nil, nil, err + } + + return user, pat, nil +} + +// userHasAllowedGroup checks if a user belongs to any of the allowed groups. +func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { + for _, userGroup := range userGroups { + for _, allowedGroup := range allowedGroups { + if userGroup == allowedGroup { + return true + } + } + } + return false +} diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go new file mode 100644 index 000000000..bc7066548 --- /dev/null +++ b/management/server/auth/manager_mock.go @@ -0,0 +1,54 @@ +package auth + +import ( + "context" + + "github.com/golang-jwt/jwt" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" +) + +var ( + _ Manager = (*MockManager)(nil) +) + +// @note really dislike this mocking approach but rather than have to do additional test refactoring. +type MockManager struct { + ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + MarkPATUsedFunc func(ctx context.Context, tokenID string) error + GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) +} + +// EnsureUserAccessByJWTGroups implements Manager. +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if m.EnsureUserAccessByJWTGroupsFunc != nil { + return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) + } + return nbcontext.UserAuth{}, nil +} + +// GetPATInfo implements Manager. +func (m *MockManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { + if m.GetPATInfoFunc != nil { + return m.GetPATInfoFunc(ctx, token) + } + return &types.User{}, &types.PersonalAccessToken{}, "", "", nil +} + +// MarkPATUsed implements Manager. +func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { + if m.MarkPATUsedFunc != nil { + return m.MarkPATUsedFunc(ctx, tokenID) + } + return nil +} + +// ValidateAndParseToken implements Manager. +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { + if m.ValidateAndParseTokenFunc != nil { + return m.ValidateAndParseTokenFunc(ctx, value) + } + return nbcontext.UserAuth{}, &jwt.Token{}, nil +} diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go new file mode 100644 index 000000000..55fb1e31a --- /dev/null +++ b/management/server/auth/manager_test.go @@ -0,0 +1,407 @@ +package auth_test + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/auth" + nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + account := &types.Account{ + Id: "account_id", + Users: map[string]*types.User{"someUser": { + Id: "someUser", + PATs: map[string]*types.PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + UserID: "someUser", + HashedToken: encodedHashedToken, + }, + }, + }}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + user, pat, _, _, err := manager.GetPATInfo(context.Background(), token) + if err != nil { + t.Fatalf("Error when getting Account from PAT: %s", err) + } + + assert.Equal(t, "account_id", user.AccountID) + assert.Equal(t, "someUser", user.Id) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) +} + +func TestAuthManager_MarkPATUsed(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + account := &types.Account{ + Id: "account_id", + Users: map[string]*types.User{"someUser": { + Id: "someUser", + PATs: map[string]*types.PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + }, + }, + }}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + err = manager.MarkPATUsed(context.Background(), "tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = store.GetAccount(context.Background(), "account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) +} + +func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + userId := "user-id" + domain := "test.domain" + + account := &types.Account{ + Id: "account_id", + Domain: domain, + Users: map[string]*types.User{"someUser": { + Id: "someUser", + }}, + Settings: &types.Settings{}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + // this has been validated and parsed by ValidateAndParseToken + userAuth := nbcontext.UserAuth{ + AccountId: account.Id, + Domain: domain, + UserId: userId, + DomainCategory: "test-category", + // Groups: []string{"group1", "group2"}, + } + + // these tests only assert groups are parsed from token as per account settings + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}) + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + t.Run("JWT groups disabled", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("User impersonated", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("User PAT", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("JWT groups enabled without claim name", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account missing groups claim name") + }) + + t.Run("JWT groups enabled without allowed groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match") + }) + + t.Run("User in allowed JWT groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + account.Settings.JWTAllowGroups = []string{"group1"} + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + + require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match") + }) + + t.Run("User not in allowed JWT groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + account.Settings.JWTAllowGroups = []string{"not-a-group"} + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + _, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.Error(t, err, "ensure user access is not in allowed groups") + }) +} + +func TestAuthManager_ValidateAndParseToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Cache-Control", "max-age=30") // set a 30s expiry to these keys + http.ServeFile(w, r, "test_data/jwks.json") + })) + defer server.Close() + + issuer := "http://issuer.local" + audience := "http://audience.local" + userIdClaim := "" // defaults to "sub" + + // we're only testing with RSA256 + keyData, _ := os.ReadFile("test_data/sample_key") + key, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData) + keyId := "test-key" + + // note, we can use a nil store because ValidateAndParseToken does not use it in it's flow + manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false) + + customClaim := func(name string) string { + return fmt.Sprintf("%s/%s", audience, name) + } + + lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z" + + tests := []struct { + name string + tokenFunc func() string + expected *nbcontext.UserAuth // nil indicates expected error + }{ + { + name: "Valid with custom claims", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour * 1).Unix(), + "sub": "user-id|123", + customClaim(nbjwt.AccountIDSuffix): "account-id|567", + customClaim(nbjwt.DomainIDSuffix): "http://localhost", + customClaim(nbjwt.DomainCategorySuffix): "private", + customClaim(nbjwt.LastLoginSuffix): lastLogin.Format(time.RFC3339), + customClaim(nbjwt.Invited): false, + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + expected: &nbcontext.UserAuth{ + UserId: "user-id|123", + AccountId: "account-id|567", + Domain: "http://localhost", + DomainCategory: "private", + LastLogin: lastLogin, + Invited: false, + }, + }, + { + name: "Valid without custom claims", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + expected: &nbcontext.UserAuth{ + UserId: "user-id|123", + }, + }, + { + name: "Expired token", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Add(time.Hour * -2).Unix(), + "exp": time.Now().Add(time.Hour * -1).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + }, + { + name: "Not yet valid", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Add(time.Hour).Unix(), + "exp": time.Now().Add(time.Hour * 2).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + }, + { + name: "Invalid signature", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + parts := strings.Split(tokenString, ".") + parts[2] = "invalid-signature" + return strings.Join(parts, ".") + }, + }, + { + name: "Invalid issuer", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": "not-the-issuer", + "aud": []string{audience}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + }, + { + name: "Invalid audience", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{"not-the-audience"}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + }, + { + name: "Invalid user claim", + tokenFunc: func() string { + token := jwt.New(jwt.SigningMethodRS256) + token.Header["kid"] = keyId + token.Claims = jwt.MapClaims{ + "iss": issuer, + "aud": []string{audience}, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "not-sub": "user-id|123", + } + tokenString, _ := token.SignedString(key) + return tokenString + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenString := tt.tokenFunc() + + userAuth, token, err := manager.ValidateAndParseToken(context.Background(), tokenString) + + if tt.expected != nil { + assert.NoError(t, err) + assert.True(t, token.Valid) + assert.Equal(t, *tt.expected, userAuth) + } else { + assert.Error(t, err) + assert.Nil(t, token) + assert.Empty(t, userAuth) + } + }) + } + +} diff --git a/management/server/auth/test_data/jwks.json b/management/server/auth/test_data/jwks.json new file mode 100644 index 000000000..8080f5599 --- /dev/null +++ b/management/server/auth/test_data/jwks.json @@ -0,0 +1,11 @@ +{ + "keys": [ + { + "kty": "RSA", + "kid": "test-key", + "use": "sig", + "n": "4f5wg5l2hKsTeNem_V41fGnJm6gOdrj8ym3rFkEU_wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn_MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR-1DcKJzQBSTAGnpYVaqpsARap-nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7w", + "e": "AQAB" + } + ] +} \ No newline at end of file diff --git a/management/server/auth/test_data/sample_key b/management/server/auth/test_data/sample_key new file mode 100644 index 000000000..e69284a3f --- /dev/null +++ b/management/server/auth/test_data/sample_key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn +SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i +cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC +PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR +ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA +Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3 +n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy +MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9 +POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE +KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM +IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn +FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY +mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj +FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U +I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs +2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn +/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT +OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86 +EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+ +hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0 +4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb +mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry +eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3 +CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+ +9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/management/server/auth/test_data/sample_key.pub b/management/server/auth/test_data/sample_key.pub new file mode 100644 index 000000000..d5b7f7102 --- /dev/null +++ b/management/server/auth/test_data/sample_key.pub @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41 +fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7 +mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp +HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2 +XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b +ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy +7wIDAQAB +-----END PUBLIC KEY----- \ No newline at end of file diff --git a/management/server/config.go b/management/server/config.go index 397b5f0e6..ce2ff4d16 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -2,7 +2,6 @@ package server import ( "net/netip" - "net/url" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" @@ -180,9 +179,3 @@ type ReverseProxy struct { // trusted IP prefixes. TrustedPeers []netip.Prefix } - -// validateURL validates input http url -func validateURL(httpURL string) bool { - _, err := url.ParseRequestURI(httpURL) - return err == nil -} diff --git a/management/server/context/auth.go b/management/server/context/auth.go new file mode 100644 index 000000000..5cb28ddb7 --- /dev/null +++ b/management/server/context/auth.go @@ -0,0 +1,60 @@ +package context + +import ( + "context" + "fmt" + "net/http" + "time" +) + +type key int + +const ( + UserAuthContextKey key = iota +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} + +func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { + return GetUserAuthFromContext(r.Context()) +} + +func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { + return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) +} + +func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { + return userAuth, nil + } + return UserAuth{}, fmt.Errorf("user auth not in context") +} + +func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { + //nolint + ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId) + //nolint + ctx = context.WithValue(ctx, AccountIDKey, userAuth.AccountId) + return context.WithValue(ctx, UserAuthContextKey, userAuth) +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index c3a36153a..9f77fd242 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/auth" nbContext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" @@ -39,11 +39,10 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config secretsManager SecretsManager - jwtValidator jwtclaims.JWTValidator - jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager peerLocks sync.Map + authManager auth.Manager } // NewServer creates a new Management server @@ -56,29 +55,13 @@ func NewServer( secretsManager SecretsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager, + authManager auth.Manager, ) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err } - var jwtValidator jwtclaims.JWTValidator - - if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtValidator, err = jwtclaims.NewJWTValidator( - ctx, - config.HttpConfig.AuthIssuer, - config.GetAuthAudiences(), - config.HttpConfig.AuthKeysLocation, - config.HttpConfig.IdpSignKeyRefreshEnabled, - ) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) - } - } else { - log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware") - } - if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { @@ -89,16 +72,6 @@ func NewServer( } } - var audience, userIDClaim string - if config.HttpConfig != nil { - audience = config.HttpConfig.AuthAudience - userIDClaim = config.HttpConfig.AuthUserIDClaim - } - jwtClaimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ) - return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -107,8 +80,7 @@ func NewServer( settingsManager: settingsManager, config: config, secretsManager: secretsManager, - jwtValidator: jwtValidator, - jwtClaimsExtractor: jwtClaimsExtractor, + authManager: authManager, appMetrics: appMetrics, ephemeralManager: ephemeralManager, }, nil @@ -294,26 +266,37 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p } func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { - if s.jwtValidator == nil { - return "", status.Error(codes.Internal, "no jwt validator set") + if s.authManager == nil { + return "", status.Errorf(codes.Internal, "missing auth manager") } - token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken) + userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } - claims := s.jwtClaimsExtractor.FromToken(token) + // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) + accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } - if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil { + if userAuth.AccountId != accountId { + log.WithContext(ctx).Debugf("gRPC server sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) + userAuth.AccountId = accountId + } + + userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token) + if err != nil { return "", status.Error(codes.PermissionDenied, err.Error()) } - return claims.UserId, nil + err = s.accountManager.SyncUserJWTGroups(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err) + } + + return userAuth.UserId, nil } func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index eb1cfb5dd..7dd277daa 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -13,9 +13,9 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/handlers/accounts" "github.com/netbirdio/netbird/management/server/http/handlers/dns" "github.com/netbirdio/netbird/management/server/http/handlers/events" @@ -27,8 +27,12 @@ import ( "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" +<<<<<<< HEAD "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" +======= + "github.com/netbirdio/netbird/management/server/integrated_validator" +>>>>>>> main nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -39,55 +43,63 @@ import ( const apiPrefix = "/api" // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. +<<<<<<< HEAD func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), ) +======= +func NewAPIHandler( + ctx context.Context, + accountManager s.AccountManager, + networksManager nbnetworks.Manager, + resourceManager resources.Manager, + routerManager routers.Manager, + groupsManager nbgroups.Manager, + LocationManager geolocation.Geolocation, + authManager auth.Manager, + appMetrics telemetry.AppMetrics, + config *s.Config, + integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +>>>>>>> main authMiddleware := middleware.NewAuthMiddleware( - accountManager.GetPATInfo, - jwtValidator.ValidateAndParse, - accountManager.MarkPATUsed, - accountManager.CheckUserAccessByJWTGroups, - claimsExtractor, - authCfg.Audience, - authCfg.UserIDClaim, + authManager, + accountManager.GetAccountIDFromUserAuth, + accountManager.SyncUserJWTGroups, ) corsMiddleware := cors.AllowAll() - claimsExtractor = jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ) - - acMiddleware := middleware.NewAccessControl( - authCfg.Audience, - authCfg.UserIDClaim, - accountManager.GetUser) + acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth) rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() prefix := apiPrefix router := rootRouter.PathPrefix(prefix).Subrouter() + router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) +<<<<<<< HEAD if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController); err != nil { +======= + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter()); err != nil { +>>>>>>> main return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(accountManager, authCfg, router) - peers.AddEndpoints(accountManager, authCfg, router) - users.AddEndpoints(accountManager, authCfg, router) - setup_keys.AddEndpoints(accountManager, authCfg, router) - policies.AddEndpoints(accountManager, LocationManager, authCfg, router) - groups.AddEndpoints(accountManager, authCfg, router) - routes.AddEndpoints(accountManager, authCfg, router) - dns.AddEndpoints(accountManager, authCfg, router) - events.AddEndpoints(accountManager, authCfg, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) + accounts.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router) + users.AddEndpoints(accountManager, router) + setup_keys.AddEndpoints(accountManager, router) + policies.AddEndpoints(accountManager, LocationManager, router) + groups.AddEndpoints(accountManager, router) + routes.AddEndpoints(accountManager, router) + dns.AddEndpoints(accountManager, router) + events.AddEndpoints(accountManager, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) return rootRouter, nil } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index a23628cdc..bc0054a7f 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -9,47 +9,42 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - accountsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + accountsHandler := newHandler(accountManager) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -62,13 +57,14 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + _, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) accountID := vars["accountId"] if len(accountID) == 0 { @@ -125,7 +121,12 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { // deleteAccount is a HTTP DELETE handler to delete an account func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + vars := mux.Vars(r) targetAccountID := vars["accountId"] if len(targetAccountID) == 0 { @@ -133,7 +134,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { return } - err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId) + err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e8a599863..a8d57a13f 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -13,19 +13,16 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *types.Account, admin *types.User) *handler { +func initAccountsTestData(account *types.Account) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return account.Id, admin.Id, nil - }, GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, @@ -44,15 +41,6 @@ func initAccountsTestData(account *types.Account, admin *types.User) *handler { return accCopy, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_account", - } - }), - ), } } @@ -75,7 +63,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, }, - }, adminUser) + }) tt := []struct { name string @@ -191,6 +179,11 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: adminUser.Id, + AccountId: accountID, + Domain: "hotmail.com", + }) router := mux.NewRouter() router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 112eee179..6ff938369 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -8,51 +8,44 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account type dnsSettingsHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - addDNSSettingEndpoint(accountManager, authCfg, router) - addDNSNameserversEndpoint(accountManager, authCfg, router) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + addDNSSettingEndpoint(accountManager, router) + addDNSNameserversEndpoint(accountManager, router) } -func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) +func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager) router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") } // newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler -func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { - return &dnsSettingsHandler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler { + return &dnsSettingsHandler{accountManager: accountManager} } // getDNSSettings returns the DNS settings for the account func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -68,13 +61,14 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque // updateDNSSettings handles update to DNS settings of an account func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PutApiDnsSettingsJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 9ca1dc032..ca81adf43 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -17,7 +17,8 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -52,19 +53,7 @@ func initDNSSettingsTestData() *dnsSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testDNSSettingsAccountID, - } - }), - ), } } @@ -118,6 +107,11 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, + AccountId: testingDNSSettingsAccount.Id, + Domain: testingDNSSettingsAccount.Domain, + }) router := mux.NewRouter() router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index 09047e231..33d070477 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -10,21 +10,19 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) // nameserversHandler is the nameserver group handler of the account type nameserversHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - nameserversHandler := newNameserversHandler(accountManager, authCfg) +func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager) router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") @@ -33,26 +31,21 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg con } // newNameserversHandler returns a new instance of nameserversHandler handler -func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { - return &nameserversHandler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler { + return &nameserversHandler{accountManager: accountManager} } // getAllNameservers returns the list of nameserver groups for the account func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -69,13 +62,14 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re // createNameserverGroup handles nameserver group creation request func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiDnsNameserversJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -102,13 +96,14 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt // updateNameserverGroup handles update to a nameserver group identified by a given ID func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -153,13 +148,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt // deleteNameserverGroup handles nameserver group deletion request func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -177,14 +173,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt // getNameserverGroup handles a nameserver group Get request identified by ID func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index c6561e4d8..45283bc37 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -18,7 +18,8 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -81,19 +82,7 @@ func initNameserversTestData() *nameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testNSGroupAccountID, - } - }), - ), } } @@ -204,6 +193,11 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + AccountId: testNSGroupAccountID, + Domain: "hotmail.com", + }) router := mux.NewRouter() router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go index 62da59535..0fb2295a8 100644 --- a/management/server/http/handlers/events/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -10,44 +10,37 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" ) // handler HTTP handler type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - eventsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + eventsHandler := newHandler(accountManager) router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") } // newHandler creates a new events handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { - return &handler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newHandler(accountManager server.AccountManager) *handler { + return &handler{accountManager: accountManager} } // getAllEvents list of the given account func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index fd603f289..3a643fe90 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -13,9 +13,10 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" ) @@ -29,22 +30,10 @@ func initEventsTestData(account string, events ...*activity.Event) *handler { } return []*activity.Event{}, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) { return make(map[string]*types.UserInfo), nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_account", - } - }), - ), } } @@ -199,6 +188,11 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_account", + }) router := mux.NewRouter() router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 383a33dea..2d0b8bdbd 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -7,24 +7,23 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + nbcontext "github.com/netbirdio/netbird/management/server/context" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - groupsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + groupsHandler := newHandler(accountManager) router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") @@ -33,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler creates a new groups handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllGroups list for the account func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { @@ -75,13 +70,14 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { // updateGroup handles update to a group identified by a given ID func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { @@ -164,13 +160,14 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { // createGroup handles group creation request func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiGroupsJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -223,13 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { // deleteGroup handles group deletion request func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) @@ -253,12 +251,13 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { // getGroup returns a group func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + + accountID, userID := userAuth.AccountId, userAuth.UserId groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index c82d33240..f4ac34e53 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -18,9 +18,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -59,9 +59,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return group, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil @@ -87,15 +84,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -134,6 +122,11 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") @@ -255,6 +248,11 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups", p.createGroup).Methods("POST") @@ -332,7 +330,11 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index f716348d6..bb6b97267 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -10,11 +10,10 @@ import ( log "github.com/sirupsen/logrus" s "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -31,16 +30,14 @@ type handler struct { routerManager routers.Manager accountManager s.AccountManager - groupsManager groups.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + groupsManager groups.Manager } -func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - addRouterEndpoints(routerManager, extractFromToken, authCfg, router) - addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) { + addRouterEndpoints(routerManager, router) + addResourceEndpoints(resourceManager, groupsManager, router) - networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager) router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") @@ -48,29 +45,25 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") } -func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler { return &handler{ - networksManager: networksManager, - resourceManager: resourceManager, - routerManager: routerManager, - groupsManager: groupsManager, - accountManager: accountManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, + groupsManager: groupsManager, + accountManager: accountManager, } } func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -105,12 +98,12 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { } func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId var req api.NetworkRequest err = json.NewDecoder(r.Body).Decode(&req) @@ -141,12 +134,12 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] @@ -179,13 +172,13 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -229,13 +222,13 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index f2dc8e3b8..fba7026e8 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -1,30 +1,26 @@ package networks import ( - "context" "encoding/json" "net/http" "github.com/gorilla/mux" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" ) type resourceHandler struct { - resourceManager resources.Manager - groupsManager groups.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + resourceManager resources.Manager + groupsManager groups.Manager } -func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg) +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, groupsManager) router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") @@ -33,26 +29,21 @@ func addResourceEndpoints(resourcesManager resources.Manager, groupsManager grou router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") } -func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { +func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler { return &resourceHandler{ - resourceManager: resourceManager, - groupsManager: groupsManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + resourceManager: resourceManager, + groupsManager: groupsManager, } } func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) if err != nil { @@ -76,13 +67,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -106,13 +98,14 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt } func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.NetworkResourceRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -144,13 +137,13 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) } func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) @@ -171,13 +164,13 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { } func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId var req api.NetworkResourceRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -209,12 +202,12 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) } func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 7ca95d902..f98da4966 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -1,28 +1,24 @@ package networks import ( - "context" "encoding/json" "net/http" "github.com/gorilla/mux" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" ) type routersHandler struct { - routersManager routers.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + routersManager routers.Manager } -func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) +func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager) router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") @@ -30,25 +26,21 @@ func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ct router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") } -func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler { +func newRoutersHandler(routersManager routers.Manager) *routersHandler { return &routersHandler{ - routersManager: routersManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + routersManager: routersManager, } } func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networkID := mux.Vars(r)["networkId"] routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) if err != nil { @@ -65,13 +57,14 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networkID := mux.Vars(r)["networkId"] var req api.NetworkRouterRequest err = json.NewDecoder(r.Body).Decode(&req) @@ -96,13 +89,14 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) @@ -115,13 +109,14 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.NetworkRouterRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -146,13 +141,13 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 723ab4f67..a907a0f24 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,11 +10,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -22,12 +21,11 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - peersHandler := NewHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + peersHandler := NewHandler(accountManager) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -35,13 +33,9 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // NewHandler creates a new peers Handler -func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { +func NewHandler(accountManager server.AccountManager) *Handler { return &Handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -149,12 +143,13 @@ func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peer // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -179,17 +174,22 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } +<<<<<<< HEAD nameFilter := r.URL.Query().Get("name") ipFilter := r.URL.Query().Get("ip") peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter) +======= + accountID, userID := userAuth.AccountId, userAuth.UserId + + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) +>>>>>>> main if err != nil { util.WriteError(r.Context(), err, w) return @@ -233,13 +233,14 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPee // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 248611bd2..cb60ae4f1 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -15,8 +15,8 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" @@ -25,16 +25,13 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -type ctxKey string - const ( testPeerID = "test_peer" noUpdateChannelTestPeerID = "no-update-channel" - adminUser = "admin_user" - regularUser = "regular_user" - serviceUser = "service_user" - userIDKey ctxKey = "user_id" + adminUser = "admin_user" + regularUser = "regular_user" + serviceUser = "service_user" ) func initTestMetaData(peers ...*nbpeer.Peer) *Handler { @@ -146,9 +143,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { return account, nil }, @@ -167,16 +161,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return ok }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - userID := r.Context().Value(userIDKey).(string) - return jwtclaims.AuthorizationClaims{ - UserId: userID, - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -267,8 +251,11 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - ctx := context.WithValue(context.Background(), userIDKey, "admin_user") - req = req.WithContext(ctx) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "admin_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") @@ -412,8 +399,11 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID) - req = req.WithContext(ctx) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: tc.callerUserID, + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index fc5839baa..fbdc324d6 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -13,9 +13,9 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" @@ -43,23 +43,11 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler { return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return types.NewAdminUser(id), nil }, }, geolocationManager: geo, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -112,6 +100,11 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") @@ -200,6 +193,11 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index 161d97402..c4868f879 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -7,11 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) @@ -23,24 +22,19 @@ var ( type geolocationsHandler struct { accountManager server.AccountManager geolocationManager geolocation.Geolocation - claimsExtractor *jwtclaims.ClaimsExtractor } -func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager) router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler -func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *geolocationsHandler { return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -104,12 +98,13 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. } func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - claims := l.claimsExtractor.FromRequestContext(r) - _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { return err } + _, userID := userAuth.AccountId, userAuth.UserId + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index a748e73b8..63fc8a03b 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -8,51 +8,46 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - policiesHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + policiesHandler := newHandler(accountManager) router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") - addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) + addPostureCheckEndpoint(accountManager, locationManager, router) } // newHandler creates a new policies handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllPolicies list for the account func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -80,13 +75,14 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { // updatePolicy handles update to a policy identified by a given ID func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -105,13 +101,14 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { // createPolicy handles policy creation request func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + h.savePolicy(w, r, accountID, userID, "") } @@ -306,13 +303,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s // deletePolicy handles policy deletion request func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -330,13 +327,14 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { // getPolicy handles a group Get request identified by ID func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 8fbf84d4b..6450295eb 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -44,9 +44,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { user := types.NewAdminUser(userID) return &types.Account{ @@ -65,15 +62,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { }, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -115,6 +103,11 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") @@ -274,6 +267,11 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index ce0d4878c..e6e58da58 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -7,11 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) @@ -20,40 +19,35 @@ import ( type postureChecksHandler struct { accountManager server.AccountManager geolocationManager geolocation.Geolocation - claimsExtractor *jwtclaims.ClaimsExtractor } -func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager) router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") - addLocationsEndpoint(accountManager, locationManager, authCfg, router) + addLocationsEndpoint(accountManager, locationManager, router) } // newPostureChecksHandler creates a new PostureChecks handler -func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *postureChecksHandler { return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllPostureChecks list for the account func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -70,13 +64,14 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt // updatePostureCheck handles update to a posture check identified by a given ID func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -95,25 +90,26 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http // createPostureCheck handles posture check creation request func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + p.savePostureChecks(w, r, accountID, userID, "") } // getPostureCheck handles a posture check Get request identified by ID func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -132,13 +128,13 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re // deletePostureCheck handles posture check deletion request func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index 237687fd4..e3844caa2 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -14,9 +14,9 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -66,20 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH } return accountPostureChecks, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, }, geolocationManager: &geolocation.Mock{}, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -187,6 +175,11 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") @@ -835,6 +828,11 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) defaultHandler := *p if tc.setupHandlerFunc != nil { diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 6b6c37910..0f0d24780 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -10,10 +10,9 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -22,12 +21,11 @@ const failedToConvertRoute = "failed to convert route to response: %v" // handler is the routes handler of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - routesHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + routesHandler := newHandler(accountManager) router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") @@ -36,25 +34,22 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler returns a new instance of routes handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllRoutes returns the list of routes for the account func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -75,13 +70,14 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { // createRoute handles route creation request func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiRoutesJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -172,13 +168,13 @@ func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { // updateRoute handles update to a route identified by a given ID func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { @@ -265,13 +261,13 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { // deleteRoute handles route deletion request func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) @@ -289,13 +285,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { // getRoute handles a route Get request identified by ID func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index f3bd79ee4..ad1f8912d 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -16,12 +16,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/domain" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -60,32 +58,6 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &types.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Peers: map[string]*nbpeer.Peer{ - existingPeerID: { - Key: existingPeerKey, - IP: netip.MustParseAddr(existingPeerIP1).AsSlice(), - ID: existingPeerID, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - }, - }, - nonLinuxExistingPeerID: { - Key: nonLinuxExistingPeerID, - IP: netip.MustParseAddr(existingPeerIP2).AsSlice(), - ID: nonLinuxExistingPeerID, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "darwin", - }, - }, - }, - Users: map[string]*types.User{ - "test_user": types.NewAdminUser("test_user"), - }, -} - func initRoutesTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ @@ -150,20 +122,7 @@ func initRoutesTestData() *handler { } return nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - // return testingAccount, testingAccount.Users["test_user"], nil - return testingAccount.Id, testingAccount.Users["test_user"].Id, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testAccountID, - } - }), - ), } } @@ -526,6 +485,11 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: testAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 3bd3ef589..8095f43b0 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -10,22 +10,20 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - keysHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + keysHandler := newHandler(accountManager) router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") @@ -34,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler creates a new setup key handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // createSetupKey is a POST requests that creates a new SetupKey func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId req := &api.PostApiSetupKeysJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -108,12 +102,12 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { // getSetupKey is a GET request to get a SetupKey by ID func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] @@ -133,13 +127,13 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { // updateSetupKey is a PUT request to update server.SetupKey func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -174,13 +168,13 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { // getAllSetupKeys is a GET request that returns a list of SetupKey func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -196,13 +190,13 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { } func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 4912f9639..e9135469f 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -14,8 +14,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -28,14 +28,9 @@ const ( notFoundSetupKeyID = "notFoundSetupKeyID" ) -func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, - user *types.User, -) *handler { +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, allowExtraDNSLabels bool, ) (*types.SetupKey, error) { @@ -76,15 +71,6 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe return status.Errorf(status.NotFound, "key %s not found", keyID) }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: user.Id, - Domain: "hotmail.com", - AccountId: "testAccountId", - } - }), - ), } } @@ -171,12 +157,17 @@ func TestSetupKeysHandlers(t *testing.T) { }, } - handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) + handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: adminUser.Id, + Domain: "hotmail.com", + AccountId: "testAccountId", + }) router := mux.NewRouter() router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 7b93d2ae1..84fbef93e 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -7,22 +7,20 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account type patHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - tokenHandler := newPATsHandler(accountManager, authCfg) +func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager) router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") @@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config } // newPATsHandler creates a new patHandler HTTP handler -func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { +func newPATsHandler(accountManager server.AccountManager) *patHandler { return &patHandler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(userID) == 0 { @@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { // getToken is HTTP GET handler that returns a personal access token for the given user func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { // createToken is HTTP POST handler that creates a personal access token for the given user func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 9388067a4..6593de64a 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -12,11 +12,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/util" + + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -77,10 +78,6 @@ func initPATTestData() *patHandler { PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, - - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) @@ -115,15 +112,6 @@ func initPATTestData() *patHandler { return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: existingUserID, - Domain: testDomain, - AccountId: existingAccountID, - } - }), - ), } } @@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 7380dd97e..3869f21f0 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,39 +9,33 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" ) // handler is a handler that returns users of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - userHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + userHandler := newHandler(accountManager) router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") - addUsersTokensEndpoint(accountManager, authCfg, router) + addUsersTokensEndpoint(accountManager, router) } // newHandler creates a new UsersHandler HTTP handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) } // deleteUser is a DELETE request to delete a user @@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId req := &api.PostApiUsersJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) @@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) } // getAllUsers returns a list of users of the account this user belongs to. @@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { continue } if serviceUser == "" { - users = append(users, toUserResponse(d, claims.UserId)) + users = append(users, toUserResponse(d, userID)) continue } @@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } if includeServiceUser == d.IsServiceUser { - users = append(users, toUserResponse(d, claims.UserId)) + users = append(users, toUserResponse(d, userID)) } } @@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index ff77cedff..a6a904a4c 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{ func initUsersTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return usersTestAccount.Id, claims.UserId, nil - }, GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, @@ -127,15 +124,6 @@ func initUsersTestData() *handler { return nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: existingUserID, - Domain: testDomain, - AccountId: existingAccountID, - } - }), - ), } } @@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) userHandler.getAllUsers(recorder, req) @@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") @@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) userHandler.createUser(rr, req) @@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) + rr := httptest.NewRecorder() userHandler.inviteUser(rr, req) @@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) + rr := httptest.NewRecorder() userHandler.deleteUser(rr, req) diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index c5bdf5fe7..4ed90f47b 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,30 +7,24 @@ import ( log "github.com/sirupsen/logrus" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) +type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { - claimsExtract jwtclaims.ClaimsExtractor - getUser GetUser + getUser GetUser } // NewAccessControl instance constructor -func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl { +func NewAccessControl(getUser GetUser) *AccessControl { return &AccessControl{ - claimsExtract: *jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ), getUser: getUser, } } @@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return } - claims := a.claimsExtract.FromRequestContext(r) - - user, err := a.getUser(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromRequest(r) if err != nil { - log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w) + log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w) + } + + user, err := a.getUser(r.Context(), userAuth) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get user: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w) return } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index dcf73259a..a8e6790a9 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -8,67 +8,41 @@ import ( "strings" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - nbContext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/types" ) -// GetAccountInfoFromPATFunc function -type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) - -// ValidateAndParseTokenFunc function -type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) - -// MarkPATUsedFunc function -type MarkPATUsedFunc func(ctx context.Context, token string) error - -// CheckUserAccessByJWTGroupsFunc function -type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error +type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountInfoFromPAT GetAccountInfoFromPATFunc - validateAndParseToken ValidateAndParseTokenFunc - markPATUsed MarkPATUsedFunc - checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc - claimsExtractor *jwtclaims.ClaimsExtractor - audience string - userIDClaim string + authManager auth.Manager + ensureAccount EnsureAccountFunc + syncUserJWTGroups SyncUserJWTGroupsFunc } -const ( - userProperty = "user" -) - // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, - markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, - audience string, userIdClaim string) *AuthMiddleware { - if userIdClaim == "" { - userIdClaim = jwtclaims.UserIDClaim - } - +func NewAuthMiddleware( + authManager auth.Manager, + ensureAccount EnsureAccountFunc, + syncUserJWTGroups SyncUserJWTGroupsFunc, +) *AuthMiddleware { return &AuthMiddleware{ - getAccountInfoFromPAT: getAccountInfoFromPAT, - validateAndParseToken: validateAndParseToken, - markPATUsed: markPATUsed, - checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, - claimsExtractor: claimsExtractor, - audience: audience, - userIDClaim: userIdClaim, + authManager: authManager, + ensureAccount: ensureAccount, + syncUserJWTGroups: syncUserJWTGroups, } } // Handler method of the middleware which authenticates a user either by JWT claims or by PAT func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if bypass.ShouldBypass(r.URL.Path, h, w, r) { return } @@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { switch authType { case "bearer": - err := m.checkJWTFromRequest(w, r, auth) + request, err := m.checkJWTFromRequest(r, auth) if err != nil { - log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error()) + log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } + + h.ServeHTTP(w, request) case "token": - err := m.checkPATFromRequest(w, r, auth) + request, err := m.checkPATFromRequest(r, auth) if err != nil { - log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error()) + log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } + h.ServeHTTP(w, request) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } - claims := m.claimsExtractor.FromRequestContext(r) - //nolint - ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId) - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId) - h.ServeHTTP(w, r.WithContext(ctx)) }) } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { token, err := getTokenFromJWTRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return r, fmt.Errorf("error extracting token: %w", err) } - validatedToken, err := m.validateAndParseToken(r.Context(), token) + ctx := r.Context() + + userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return err + return r, err } - if validatedToken == nil { - return nil + if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { + userAuth.AccountId = impersonate[0] + userAuth.IsChild = ok } - if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil { - return err + // we need to call this method because if user is new, we will automatically add it to existing or create a new account + accountId, _, err := m.ensureAccount(ctx, userAuth) + if err != nil { + return r, err } - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil -} + if userAuth.AccountId != accountId { + log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) + userAuth.AccountId = accountId + } -// verifyUserAccess checks if a user, based on a validated JWT token, -// is allowed access, particularly in cases where the admin enabled JWT -// group propagation and designated certain groups with access permissions. -func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error { - authClaims := m.claimsExtractor.FromToken(validatedToken) - return m.checkUserAccessByJWTGroups(ctx, authClaims) + userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) + if err != nil { + return r, err + } + + err = m.syncUserJWTGroups(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) + } + + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { token, err := getTokenFromPATRequest(auth) if err != nil { - return fmt.Errorf("error extracting token: %w", err) + return r, fmt.Errorf("error extracting token: %w", err) } - user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token) + ctx := r.Context() + user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { - return fmt.Errorf("invalid Token: %w", err) + return r, fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return fmt.Errorf("token expired") + return r, fmt.Errorf("token expired") } - err = m.markPATUsed(r.Context(), pat.ID) + err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return err + return r, err } - claimMaps := jwt.MapClaims{} - claimMaps[m.userIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory - claimMaps[jwtclaims.IsToken] = true - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil + userAuth := nbcontext.UserAuth{ + UserId: user.Id, + AccountId: user.AccountID, + Domain: accDomain, + DomainCategory: accCategory, + IsPAT: true, + } + + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } // getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // the JWT token from the Authorization header. func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") + return "", errors.New("authorization header format must be Bearer {token}") } return authHeaderParts[1], nil @@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { // the PAT token from the Authorization header. func getTokenFromPATRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { - return "", errors.New("Authorization header format must be Token {token}") + return "", errors.New("authorization header format must be Token {token}") } return authHeaderParts[1], nil diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index c1686ed44..3dc7d51cb 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -9,10 +9,14 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/auth" + nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/types" ) @@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { if token == JWT { - return &jwt.Token{ - Claims: jwt.MapClaims{ - userIDClaim: userID, - audience + jwtclaims.AccountIDSuffix: accountID, + return nbcontext.UserAuth{ + UserId: userID, + AccountId: accountID, + Domain: testAccount.Domain, + DomainCategory: testAccount.DomainCategory, }, - Valid: true, - }, nil + &jwt.Token{ + Claims: jwt.MapClaims{ + userIDClaim: userID, + audience + nbjwt.AccountIDSuffix: accountID, + }, + Valid: true, + }, nil } - return nil, fmt.Errorf("JWT invalid") + return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { @@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error { return fmt.Errorf("Should never get reached") } -func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error { - if testAccount.Id != claims.AccountId { - return fmt.Errorf("account with id %s does not exist", claims.AccountId) +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if userAuth.IsChild || userAuth.IsPAT { + return userAuth, nil } - if _, ok := testAccount.Users[claims.UserId]; !ok { - return fmt.Errorf("user with id %s does not exist", claims.UserId) + if testAccount.Id != userAuth.AccountId { + return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId) } - return nil + if _, ok := testAccount.Users[userAuth.UserId]; !ok { + return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId) + } + + return userAuth, nil } func TestAuthMiddleware_Handler(t *testing.T) { @@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) { } nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // do nothing + }) - claimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ) + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } authMiddleware := NewAuthMiddleware( - mockGetAccountInfoFromPAT, - mockValidateAndParseToken, - mockMarkPATUsed, - mockCheckUserAccessByJWTGroups, - claimsExtractor, - audience, - userIDClaim, + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) { result := rec.Result() defer result.Body.Close() + if result.StatusCode != tc.expectedStatusCode { t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode) } }) } } + +func TestAuthMiddleware_Handler_Child(t *testing.T) { + tt := []struct { + name string + path string + authHeader string + expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status + }{ + { + name: "Valid PAT Token", + path: "/test", + authHeader: "Token " + PAT, + expectedUserAuth: &nbcontext.UserAuth{ + AccountId: accountID, + UserId: userID, + Domain: testAccount.Domain, + DomainCategory: testAccount.DomainCategory, + IsPAT: true, + }, + }, + { + name: "Valid PAT Token ignores child", + path: "/test?account=xyz", + authHeader: "Token " + PAT, + expectedUserAuth: &nbcontext.UserAuth{ + AccountId: accountID, + UserId: userID, + Domain: testAccount.Domain, + DomainCategory: testAccount.DomainCategory, + IsPAT: true, + }, + }, + { + name: "Valid JWT Token", + path: "/test", + authHeader: "Bearer " + JWT, + expectedUserAuth: &nbcontext.UserAuth{ + AccountId: accountID, + UserId: userID, + Domain: testAccount.Domain, + DomainCategory: testAccount.DomainCategory, + }, + }, + + { + name: "Valid JWT Token with child", + path: "/test?account=xyz", + authHeader: "Bearer " + JWT, + expectedUserAuth: &nbcontext.UserAuth{ + AccountId: "xyz", + UserId: userID, + Domain: testAccount.Domain, + DomainCategory: testAccount.DomainCategory, + IsChild: true, + }, + }, + } + + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + ) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromRequest(r) + if tc.expectedUserAuth != nil { + assert.NoError(t, err) + assert.Equal(t, *tc.expectedUserAuth, userAuth) + } else { + assert.Error(t, err) + assert.Empty(t, userAuth) + } + })) + + req := httptest.NewRequest("GET", "http://testing"+tc.path, nil) + req.Header.Set("Authorization", tc.authHeader) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + result := rec.Result() + defer result.Body.Close() + + if tc.expectedUserAuth != nil { + assert.Equal(t, 200, result.StatusCode) + } else { + assert.Equal(t, 401, result.StatusCode) + } + }) + } +} diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 7f8eee6e7..e2c2c1d85 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -77,13 +77,13 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70}, - "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, - "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, - "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70}, + "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70}, + "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70}, + "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200}, "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, } @@ -111,9 +111,9 @@ func BenchmarkGetOnePeer(b *testing.B) { func BenchmarkGetAllPeers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150}, - "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100}, + "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100}, "Peers - L": {MinMsPerOpLocal: 110, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, "Groups - L": {MinMsPerOpLocal: 150, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 500}, "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 0baf76328..b7deab334 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -48,13 +48,12 @@ func BenchmarkUpdateUser(b *testing.B) { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) - recorder := httptest.NewRecorder() - for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { @@ -97,13 +96,12 @@ func BenchmarkGetOneUser(b *testing.B) { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) - recorder := httptest.NewRecorder() - for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { @@ -118,26 +116,25 @@ func BenchmarkGetOneUser(b *testing.B) { func BenchmarkGetAllUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10}, - "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10}, - "Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15}, - "Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50}, - "Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55}, - "Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55}, - "Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55}, - "Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75}, + "Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75}, + "Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100}, + "Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100}, + "Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100}, + "Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100}, + "Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300}, } log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) - recorder := httptest.NewRecorder() - for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { @@ -152,26 +149,25 @@ func BenchmarkGetAllUsers(b *testing.B) { func BenchmarkDeleteUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, - "Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, + "Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50}, } log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) - recorder := httptest.NewRecorder() - for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) + recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index a1e121fdf..1ed3355ef 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -3,6 +3,7 @@ package testing_tools import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -13,7 +14,11 @@ import ( "testing" "time" +<<<<<<< HEAD "github.com/netbirdio/management-integrations/integrations" +======= + "github.com/golang-jwt/jwt" +>>>>>>> main "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -23,11 +28,11 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/http/configs" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -36,6 +41,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -120,13 +126,26 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve t.Fatalf("Failed to create manager: %v", err) } + // @note this is required so that PAT's validate from store, but JWT's are mocked + authManager := auth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + networksManagerMock := networks.NewManagerMock() resourcesManagerMock := resources.NewManagerMock() routersManagerMock := routers.NewManagerMock() groupsManagerMock := groups.NewManagerMock() +<<<<<<< HEAD permissionsManagerMock := permissions.NewManagerMock() peersManager := peers.NewManager(store, permissionsManagerMock) apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock, proxyController, permissionsManagerMock, peersManager) +======= + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, &server.Config{}, validatorMock) +>>>>>>> main if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -316,3 +335,25 @@ func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected) } } + +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { + userAuth := nbcontext.UserAuth{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + userAuth.UserId = token + userAuth.AccountId = "testAccountId" + userAuth.Domain = "test.com" + userAuth.DomainCategory = "private" + case "otherUserId": + userAuth.UserId = "otherUserId" + userAuth.AccountId = "otherAccountId" + userAuth.Domain = "other.com" + userAuth.DomainCategory = "private" + case "invalidToken": + return userAuth, nil, errors.New("invalid token") + } + + jwtToken := jwt.New(jwt.SigningMethodHS256) + return userAuth, jwtToken, nil +} diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go deleted file mode 100644 index 2527acbe3..000000000 --- a/management/server/jwtclaims/claims.go +++ /dev/null @@ -1,19 +0,0 @@ -package jwtclaims - -import ( - "time" - - "github.com/golang-jwt/jwt" -) - -// AuthorizationClaims stores authorization information from JWTs -type AuthorizationClaims struct { - UserId string - AccountId string - Domain string - DomainCategory string - LastLogin time.Time - Invited bool - - Raw jwt.MapClaims -} diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go deleted file mode 100644 index eccd7c9e7..000000000 --- a/management/server/jwtclaims/extractor_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package jwtclaims - -import ( - "context" - "net/http" - "testing" - "time" - - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" -) - -func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience string) *http.Request { - t.Helper() - const layout = "2006-01-02T15:04:05.999Z" - - claimMaps := jwt.MapClaims{} - if claims.UserId != "" { - claimMaps[UserIDClaim] = claims.UserId - } - if claims.AccountId != "" { - claimMaps[audience+AccountIDSuffix] = claims.AccountId - } - if claims.Domain != "" { - claimMaps[audience+DomainIDSuffix] = claims.Domain - } - if claims.DomainCategory != "" { - claimMaps[audience+DomainCategorySuffix] = claims.DomainCategory - } - if claims.LastLogin != (time.Time{}) { - claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout) - } - - if claims.Invited { - claimMaps[audience+Invited] = true - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) - require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint - - return testRequest -} - -func TestExtractClaimsFromRequestContext(t *testing.T) { - type test struct { - name string - inputAuthorizationClaims AuthorizationClaims - inputAudiance string - testingFunc require.ComparisonAssertionFunc - expectedMSG string - } - - const layout = "2006-01-02T15:04:05.999Z" - lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z") - - testCase1 := test{ - name: "All Claim Fields", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - AccountId: "testAcc", - LastLogin: lastLogin, - DomainCategory: "public", - Invited: true, - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "https://login/wt_account_domain_category": "public", - "https://login/wt_account_id": "testAcc", - "https://login/nb_last_login": lastLogin.Format(layout), - "sub": "test", - "https://login/" + Invited: true, - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase2 := test{ - name: "Domain Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - AccountId: "testAcc", - Raw: jwt.MapClaims{ - "https://login/wt_account_id": "testAcc", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase3 := test{ - name: "Account ID Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase4 := test{ - name: "Category Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - AccountId: "testAcc", - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "https://login/wt_account_id": "testAcc", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase5 := test{ - name: "Only User ID Is set", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Raw: jwt.MapClaims{ - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { - t.Run(testCase.name, func(t *testing.T) { - request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) - - extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance)) - extractedClaims := extractor.FromRequestContext(request) - - testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG) - }) - } -} - -func TestExtractClaimsSetOptions(t *testing.T) { - t.Helper() - type test struct { - name string - extractor *ClaimsExtractor - check func(t *testing.T, c test) - } - - testCase1 := test{ - name: "No custom options", - extractor: NewClaimsExtractor(), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.authAudience != "" { - t.Error("audience should be empty") - return - } - if c.extractor.userIDClaim != UserIDClaim { - t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) - return - } - if c.extractor.FromRequestContext == nil { - t.Error("from request context should not be nil") - return - } - }, - } - - testCase2 := test{ - name: "Custom audience", - extractor: NewClaimsExtractor(WithAudience("https://login/")), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.authAudience != "https://login/" { - t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience) - return - } - }, - } - - testCase3 := test{ - name: "Custom user id claim", - extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.userIDClaim != "customUserId" { - t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim) - return - } - }, - } - - testCase4 := test{ - name: "Custom extractor from request context", - extractor: NewClaimsExtractor( - WithFromRequestContext(func(r *http.Request) AuthorizationClaims { - return AuthorizationClaims{ - UserId: "testCustomRequest", - } - })), - check: func(t *testing.T, c test) { - t.Helper() - claims := c.extractor.FromRequestContext(&http.Request{}) - if claims.UserId != "testCustomRequest" { - t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId) - return - } - }, - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { - t.Run(testCase.name, func(t *testing.T) { - testCase.check(t, testCase) - }) - } -} diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go deleted file mode 100644 index 79e59e76f..000000000 --- a/management/server/jwtclaims/jwtValidator.go +++ /dev/null @@ -1,349 +0,0 @@ -package jwtclaims - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "math/big" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/golang-jwt/jwt" - log "github.com/sirupsen/logrus" -) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool -} - -// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation -type Jwks struct { - Keys []JSONWebKey `json:"keys"` - expiresInTime time.Time -} - -// The supported elliptic curves types -const ( - // p256 represents a cryptographic elliptical curve type. - p256 = "P-256" - - // p384 represents a cryptographic elliptical curve type. - p384 = "P-384" - - // p521 represents a cryptographic elliptical curve type. - p521 = "P-521" -) - -// JSONWebKey is a representation of a Jason Web Key -type JSONWebKey struct { - Kty string `json:"kty"` - Kid string `json:"kid"` - Use string `json:"use"` - N string `json:"n"` - E string `json:"e"` - Crv string `json:"crv"` - X string `json:"x"` - Y string `json:"y"` - X5c []string `json:"x5c"` -} - -type JWTValidator interface { - ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) -} - -// jwtValidatorImpl struct to handle token validation and parsing -type jwtValidatorImpl struct { - options Options -} - -var keyNotFound = errors.New("unable to find appropriate key") - -// NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { - keys, err := getPemKeys(ctx, keysLocation) - if err != nil { - return nil, err - } - - var lock sync.Mutex - options := Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - // Verify 'aud' claim - var checkAud bool - for _, audience := range audienceList { - checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) - if checkAud { - break - } - } - if !checkAud { - return token, errors.New("invalid audience") - } - // Verify 'issuer' claim - checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(issuer, false) - if !checkIss { - return token, errors.New("invalid issuer") - } - - // If keys are rotated, verify the keys prior to token validation - if idpSignkeyRefreshEnabled { - // If the keys are invalid, retrieve new ones - if !keys.stillValid() { - lock.Lock() - defer lock.Unlock() - - refreshedKeys, err := getPemKeys(ctx, keysLocation) - if err != nil { - log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) - refreshedKeys = keys - } - - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) - - keys = refreshedKeys - } - } - - publicKey, err := getPublicKey(ctx, token, keys) - if err == nil { - return publicKey, nil - } - - msg := fmt.Sprintf("getPublicKey error: %s", err) - if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled { - msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) - } - - log.WithContext(ctx).Error(msg) - - return nil, err - }, - EnableAuthOnOptions: false, - } - - if options.UserProperty == "" { - options.UserProperty = "user" - } - - return &jwtValidatorImpl{ - options: options, - }, nil -} - -// ValidateAndParse validates the token and returns the parsed token -func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { - // If the token is empty... - if token == "" { - // Check if it was required - if m.options.CredentialsOptional { - log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil //nolint:nilnil - } - - // If we get here, the required token is missing - errorMsg := "required authorization token not found" - log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") - return nil, errors.New(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - log.WithContext(ctx).Errorf("error parsing token: %v", err) - return nil, fmt.Errorf("error parsing token: %w", err) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - errorMsg := "token is invalid" - log.WithContext(ctx).Debug(errorMsg) - return nil, errors.New(errorMsg) - } - - return parsedToken, nil -} - -// stillValid returns true if the JSONWebKey still valid and have enough time to be used -func (jwks *Jwks) stillValid() bool { - return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) -} - -func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { - resp, err := http.Get(keysLocation) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - jwks := &Jwks{} - err = json.NewDecoder(resp.Body).Decode(jwks) - if err != nil { - return jwks, err - } - - cacheControlHeader := resp.Header.Get("Cache-Control") - expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader) - jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) - - return jwks, err -} - -func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) { - // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time - - for k := range jwks.Keys { - if token.Header["kid"] != jwks.Keys[k].Kid { - continue - } - - if len(jwks.Keys[k].X5c) != 0 { - cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" - return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) - } - - if jwks.Keys[k].Kty == "RSA" { - log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK") - return getPublicKeyFromRSA(jwks.Keys[k]) - } - if jwks.Keys[k].Kty == "EC" { - log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK") - return getPublicKeyFromECDSA(jwks.Keys[k]) - } - - log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) - } - - return nil, keyNotFound -} - -func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { - - if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { - return nil, fmt.Errorf("ecdsa key incomplete") - } - - var xCoordinate []byte - if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { - return nil, err - } - - var yCoordinate []byte - if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { - return nil, err - } - - publicKey = &ecdsa.PublicKey{} - - var curve elliptic.Curve - switch jwk.Crv { - case p256: - curve = elliptic.P256() - case p384: - curve = elliptic.P384() - case p521: - curve = elliptic.P521() - } - - publicKey.Curve = curve - publicKey.X = big.NewInt(0).SetBytes(xCoordinate) - publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) - - return publicKey, nil -} - -func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { - - decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, err - } - decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, err - } - - var n, e big.Int - e.SetBytes(decodedE) - n.SetBytes(decodedN) - - return &rsa.PublicKey{ - E: int(e.Int64()), - N: &n, - }, nil -} - -// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header -func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { - // Split into individual directives - directives := strings.Split(cacheControl, ",") - - for _, directive := range directives { - directive = strings.TrimSpace(directive) - if strings.HasPrefix(directive, "max-age=") { - // Extract the max-age value - maxAgeStr := strings.TrimPrefix(directive, "max-age=") - maxAge, err := strconv.Atoi(maxAgeStr) - if err != nil { - log.WithContext(ctx).Debugf("error parsing max-age: %v", err) - return 0 - } - - return maxAge - } - } - - return 0 -} - -type JwtValidatorMock struct{} - -func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { - claimMaps := jwt.MapClaims{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - claimMaps[UserIDClaim] = token - claimMaps[AccountIDSuffix] = "testAccountId" - claimMaps[DomainIDSuffix] = "test.com" - claimMaps[DomainCategorySuffix] = "private" - case "otherUserId": - claimMaps[UserIDClaim] = "otherUserId" - claimMaps[AccountIDSuffix] = "otherAccountId" - claimMaps[DomainIDSuffix] = "other.com" - claimMaps[DomainCategorySuffix] = "private" - case "invalidToken": - return nil, errors.New("invalid token") - } - - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - return jwtToken, nil -} - diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 1c9703c58..c838c4a27 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -441,7 +441,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 56c295815..838065e49 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -206,6 +206,7 @@ func startServer( secretsManager, nil, nil, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d5bee97f3..5564aab01 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,14 +13,16 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) +var _ server.AccountManager = (*MockAccountManager)(nil) + type MockAccountManager struct { GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) @@ -29,7 +31,7 @@ type MockAccountManager struct { GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error @@ -54,8 +56,6 @@ type MockAccountManager struct { DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) - GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) - MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) @@ -80,8 +80,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -240,14 +239,6 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface -func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) { - if am.GetPATInfoFunc != nil { - return am.GetPATInfoFunc(ctx, pat) - } - return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented") -} - // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { @@ -256,14 +247,6 @@ func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, user return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented") } -// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface -func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error { - if am.MarkPATUsedFunc != nil { - return am.MarkPATUsedFunc(ctx, pat) - } - return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") -} - // CreatePAT mock implementation of GetPAT from server.AccountManager interface func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { @@ -430,11 +413,11 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { - if am.GetUserFunc != nil { - return am.GetUserFunc(ctx, claims) +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + if am.GetUserFromUserAuthFunc != nil { + return am.GetUserFromUserAuthFunc(ctx, userAuth) } - return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetUserFromUserAuth is not implemented") } func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { @@ -614,19 +597,11 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - if am.GetAccountIDFromTokenFunc != nil { - return am.GetAccountIDFromTokenFunc(ctx, claims) +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + if am.GetAccountIDFromUserAuthFunc != nil { + return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) } - return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") -} - -func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - if am.CheckUserAccessByJWTGroupsFunc != nil { - return am.CheckUserAccessByJWTGroupsFunc(ctx, claims) - } - return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromUserAuth is not implemented") } // GetPeers mocks GetPeers of the AccountManager interface @@ -859,3 +834,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco } return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") } + +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { + return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 6d86557df..1dae3999b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -15,7 +15,6 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/management/server/util" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -24,6 +23,8 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "github.com/netbirdio/netbird/management/server/util" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -615,6 +616,16 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt return groups, nil } +func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) { + var count int64 + result := s.db.Model(&types.Account{}).Count(&count) + if result.Error != nil { + return 0, fmt.Errorf("failed to get all accounts counter: %w", result.Error) + } + + return count, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { var accounts []types.Account result := s.db.Find(&accounts) @@ -1035,6 +1046,13 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data } for _, account := range fileStore.GetAllAccounts(ctx) { + _, err = account.GetGroupAll() + if err != nil { + if err := account.AddAllGroup(); err != nil { + return nil, err + } + } + err := store.SaveAccount(ctx, account) if err != nil { return nil, err diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index bdb5905bd..54649c5c1 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/google/uuid" - "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -2045,52 +2044,12 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty }, } - if err := addAllGroup(acc); err != nil { + if err := acc.AddAllGroup(); err != nil { log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } return acc } -// addAllGroup to account object if it doesn't exist -func addAllGroup(account *types.Account) error { - if len(account.Groups) == 0 { - allGroup := &types.Group{ - ID: xid.New().String(), - Name: "All", - Issued: types.GroupIssuedAPI, - } - for _, peer := range account.Peers { - allGroup.Peers = append(allGroup.Peers, peer.ID) - } - account.Groups = map[string]*types.Group{allGroup.ID: allGroup} - - id := xid.New().String() - - defaultPolicy := &types.Policy{ - ID: id, - Name: types.DefaultRuleName, - Description: types.DefaultRuleDescription, - Enabled: true, - Rules: []*types.PolicyRule{ - { - ID: id, - Name: types.DefaultRuleName, - Description: types.DefaultRuleDescription, - Enabled: true, - Sources: []string{allGroup.ID}, - Destinations: []string{allGroup.ID}, - Bidirectional: true, - Protocol: types.PolicyRuleProtocolALL, - Action: types.PolicyTrafficActionAccept, - }, - }, - } - - account.Policies = []*types.Policy{defaultPolicy} - } - return nil -} - func TestSqlStore_GetAccountNetworks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) diff --git a/management/server/store/store.go b/management/server/store/store.go index e94ba2f35..d84d699bb 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -48,6 +48,7 @@ const ( ) type Store interface { + GetAccountsCounter(ctx context.Context) (int64, error) GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) @@ -352,7 +353,46 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( return nil, nil, fmt.Errorf("failed to create test store: %v", err) } - return getSqlStoreEngine(ctx, store, kind) + err = addAllGroupToAccount(ctx, store) + if err != nil { + return nil, nil, fmt.Errorf("failed to add all group to account: %v", err) + } + + + maxRetries := 2 + for i := 0; i < maxRetries; i++ { + sqlStore, cleanUp, err := getSqlStoreEngine(ctx, store, kind) + if err == nil { + return sqlStore, cleanUp, nil + } + if i < maxRetries-1 { + time.Sleep(100 * time.Millisecond) + } + } + return nil, nil, fmt.Errorf("failed to create test store after %d attempts: %v", maxRetries, err) +} + +func addAllGroupToAccount(ctx context.Context, store Store) error { + allAccounts := store.GetAllAccounts(ctx) + for _, account := range allAccounts { + shouldSave := false + + _, err := account.GetGroupAll() + if err != nil { + if err := account.AddAllGroup(); err != nil { + return err + } + shouldSave = true + } + + if shouldSave { + err = store.SaveAccount(ctx, account) + if err != nil { + return err + } + } + } + return nil } func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index cda333d4f..8b09ec2be 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -36,4 +36,3 @@ INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|6 INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',0,'""','','',0); INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',1,'""','','',0); INSERT INTO installations VALUES(1,''); - diff --git a/management/server/types/account.go b/management/server/types/account.go index 4c68b9523..c890a7730 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + "github.com/rs/xid" log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" @@ -1525,3 +1526,43 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st return sourcePeers } + +// AddAllGroup to account object if it doesn't exist +func (a *Account) AddAllGroup() error { + if len(a.Groups) == 0 { + allGroup := &Group{ + ID: xid.New().String(), + Name: "All", + Issued: GroupIssuedAPI, + } + for _, peer := range a.Peers { + allGroup.Peers = append(allGroup.Peers, peer.ID) + } + a.Groups = map[string]*Group{allGroup.ID: allGroup} + + id := xid.New().String() + + defaultPolicy := &Policy{ + ID: id, + Name: DefaultRuleName, + Description: DefaultRuleDescription, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: id, + Name: DefaultRuleName, + Description: DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + }, + }, + } + + a.Policies = []*Policy{defaultPolicy} + } + return nil +} diff --git a/management/server/user.go b/management/server/user.go index 6ba9b68d3..381879ae6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,16 +8,16 @@ import ( "time" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - log "github.com/sirupsen/logrus" ) // createServiceUser creates a new service user under the given account. @@ -174,31 +174,26 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) } -// GetUser looks up a user by provided authorization claims. -// It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { - accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) - if err != nil { - return nil, fmt.Errorf("failed to get account with token claims %v", err) - } - - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +// GetUser looks up a user by provided nbContext.UserAuths. +// Expects account to have been created already. +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { return nil, err } // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. - newLogin := user.LastDashboardLoginChanged(claims.LastLogin) + newLogin := user.LastDashboardLoginChanged(userAuth.LastLogin) - err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { - meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) + meta := map[string]any{"timestamp": userAuth.LastLogin} + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, userAuth.AccountId, activity.DashboardLogin, meta) } return user, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index 4a532c8a6..a180a761a 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,6 +10,8 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/util" "golang.org/x/exp/maps" @@ -25,7 +27,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" - "github.com/netbirdio/netbird/management/server/jwtclaims" ) const ( @@ -925,11 +926,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - claims := jwtclaims.AuthorizationClaims{ - UserId: mockUserID, + claims := nbcontext.UserAuth{ + UserId: mockUserID, + AccountId: mockAccountID, } - user, err := am.GetUser(context.Background(), claims) + user, err := am.GetUserFromUserAuth(context.Background(), claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) } diff --git a/release_files/ui-post-install.sh b/release_files/ui-post-install.sh new file mode 100644 index 000000000..f6e8ddf92 --- /dev/null +++ b/release_files/ui-post-install.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +# Check if netbird-ui is running +if pgrep -x -f /usr/bin/netbird-ui >/dev/null 2>&1; +then + runner=$(ps --no-headers -o '%U' -p $(pgrep -x -f /usr/bin/netbird-ui) | sed 's/^[ \t]*//;s/[ \t]*$//') + # Only re-run if it was already running + pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1 + su -l - "$runner" -c 'nohup /usr/bin/netbird-ui > /dev/null 2>&1 &' +fi