diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml
index 664e8be18..4571ce753 100644
--- a/.github/workflows/golang-test-darwin.yml
+++ b/.github/workflows/golang-test-darwin.yml
@@ -1,4 +1,4 @@
-name: Test Code Darwin
+name: "Darwin"
on:
push:
@@ -12,9 +12,7 @@ concurrency:
jobs:
test:
- strategy:
- matrix:
- store: ['sqlite']
+ name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Install Go
diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml
index 0f510cb3a..e1c688b1b 100644
--- a/.github/workflows/golang-test-freebsd.yml
+++ b/.github/workflows/golang-test-freebsd.yml
@@ -1,5 +1,4 @@
-
-name: Test Code FreeBSD
+name: "FreeBSD"
on:
push:
@@ -13,6 +12,7 @@ concurrency:
jobs:
test:
+ name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index efe1a2654..3be8bcff3 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -316,7 +316,7 @@ jobs:
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
- -timeout 10m ./management/...
+ -timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
@@ -508,7 +508,7 @@ jobs:
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
- -timeout 10m ./management/...
+ -timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml
index 782e4c30a..d9ff0a84b 100644
--- a/.github/workflows/golang-test-windows.yml
+++ b/.github/workflows/golang-test-windows.yml
@@ -1,4 +1,4 @@
-name: Test Code Windows
+name: "Windows"
on:
push:
@@ -14,6 +14,7 @@ concurrency:
jobs:
test:
+ name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 6705a34ec..ca075d30f 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -1,4 +1,4 @@
-name: golangci-lint
+name: Lint
on: [pull_request]
permissions:
@@ -27,7 +27,14 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
- name: lint
+ include:
+ - os: macos-latest
+ display_name: Darwin
+ - os: windows-latest
+ display_name: Windows
+ - os: ubuntu-latest
+ display_name: Linux
+ name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml
index dcf461a34..569956a54 100644
--- a/.github/workflows/mobile-build-validation.yml
+++ b/.github/workflows/mobile-build-validation.yml
@@ -1,4 +1,4 @@
-name: Mobile build validation
+name: Mobile
on:
push:
@@ -12,6 +12,7 @@ concurrency:
jobs:
android_build:
+ name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -47,6 +48,7 @@ jobs:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
+ name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository
diff --git a/.gitignore b/.gitignore
index d0b4f82dd..abb728b19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
+vendor/
diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml
index 983aa0e78..1dd649d1b 100644
--- a/.goreleaser_ui.yaml
+++ b/.goreleaser_ui.yaml
@@ -50,6 +50,8 @@ nfpms:
- netbird-ui
formats:
- deb
+ scripts:
+ postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
@@ -67,6 +69,8 @@ nfpms:
- netbird-ui
formats:
- rpm
+ scripts:
+ postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
diff --git a/README.md b/README.md
index 0537710e9..5b136eff6 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
+
+
@@ -31,6 +33,10 @@
+
+
+ Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
+
diff --git a/client/Dockerfile b/client/Dockerfile
index 2f5ff14ae..35c1d04c2 100644
--- a/client/Dockerfile
+++ b/client/Dockerfile
@@ -1,4 +1,4 @@
-FROM alpine:3.21.0
+FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
diff --git a/client/cmd/debug.go b/client/cmd/debug.go
index c7ab87b47..c02f60aed 100644
--- a/client/cmd/debug.go
+++ b/client/cmd/debug.go
@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
+ nbstatus "github.com/netbirdio/netbird/client/status"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
- Status: getStatusOutput(cmd),
+ Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
@@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
- statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
+ statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
@@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
- statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
+ statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
@@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
return nil
}
-func getStatusOutput(cmd *cobra.Command) string {
+func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
- statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
+ statusOutputString = nbstatus.ParseToFullDetailSummary(
+ nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
+ )
}
return statusOutputString
}
diff --git a/client/cmd/status.go b/client/cmd/status.go
index 1deef487b..0ddba8b2f 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -2,108 +2,20 @@ package cmd
import (
"context"
- "encoding/json"
"fmt"
"net"
"net/netip"
- "os"
- "runtime"
- "sort"
"strings"
- "time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
- "gopkg.in/yaml.v3"
- "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
- "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
+ nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
- "github.com/netbirdio/netbird/version"
)
-type peerStateDetailOutput struct {
- FQDN string `json:"fqdn" yaml:"fqdn"`
- IP string `json:"netbirdIp" yaml:"netbirdIp"`
- PubKey string `json:"publicKey" yaml:"publicKey"`
- Status string `json:"status" yaml:"status"`
- LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
- ConnType string `json:"connectionType" yaml:"connectionType"`
- IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
- IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
- RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
- LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
- TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
- TransferSent int64 `json:"transferSent" yaml:"transferSent"`
- Latency time.Duration `json:"latency" yaml:"latency"`
- RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
- Networks []string `json:"networks" yaml:"networks"`
-}
-
-type peersStateOutput struct {
- Total int `json:"total" yaml:"total"`
- Connected int `json:"connected" yaml:"connected"`
- Details []peerStateDetailOutput `json:"details" yaml:"details"`
-}
-
-type signalStateOutput struct {
- URL string `json:"url" yaml:"url"`
- Connected bool `json:"connected" yaml:"connected"`
- Error string `json:"error" yaml:"error"`
-}
-
-type managementStateOutput struct {
- URL string `json:"url" yaml:"url"`
- Connected bool `json:"connected" yaml:"connected"`
- Error string `json:"error" yaml:"error"`
-}
-
-type relayStateOutputDetail struct {
- URI string `json:"uri" yaml:"uri"`
- Available bool `json:"available" yaml:"available"`
- Error string `json:"error" yaml:"error"`
-}
-
-type relayStateOutput struct {
- Total int `json:"total" yaml:"total"`
- Available int `json:"available" yaml:"available"`
- Details []relayStateOutputDetail `json:"details" yaml:"details"`
-}
-
-type iceCandidateType struct {
- Local string `json:"local" yaml:"local"`
- Remote string `json:"remote" yaml:"remote"`
-}
-
-type nsServerGroupStateOutput struct {
- Servers []string `json:"servers" yaml:"servers"`
- Domains []string `json:"domains" yaml:"domains"`
- Enabled bool `json:"enabled" yaml:"enabled"`
- Error string `json:"error" yaml:"error"`
-}
-
-type statusOutputOverview struct {
- Peers peersStateOutput `json:"peers" yaml:"peers"`
- CliVersion string `json:"cliVersion" yaml:"cliVersion"`
- DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
- ManagementState managementStateOutput `json:"management" yaml:"management"`
- SignalState signalStateOutput `json:"signal" yaml:"signal"`
- Relays relayStateOutput `json:"relays" yaml:"relays"`
- IP string `json:"netbirdIp" yaml:"netbirdIp"`
- PubKey string `json:"publicKey" yaml:"publicKey"`
- KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
- FQDN string `json:"fqdn" yaml:"fqdn"`
- RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
- RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
-
- Networks []string `json:"networks" yaml:"networks"`
- NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
- NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
- Events []systemEventOutput `json:"events" yaml:"events"`
-}
-
var (
detailFlag bool
ipv4Flag bool
@@ -174,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
- outputInformationHolder := convertToStatusOutputOverview(resp)
-
+ var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
var statusOutputString string
switch {
case detailFlag:
- statusOutputString = parseToFullDetailSummary(outputInformationHolder)
+ statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
case jsonFlag:
- statusOutputString, err = parseToJSON(outputInformationHolder)
+ statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
case yamlFlag:
- statusOutputString, err = parseToYAML(outputInformationHolder)
+ statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
- statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
+ statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
}
if err != nil {
@@ -215,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
func parseFilters() error {
-
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
@@ -252,176 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
}
}
-func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
- pbFullStatus := resp.GetFullStatus()
-
- managementState := pbFullStatus.GetManagementState()
- managementOverview := managementStateOutput{
- URL: managementState.GetURL(),
- Connected: managementState.GetConnected(),
- Error: managementState.Error,
- }
-
- signalState := pbFullStatus.GetSignalState()
- signalOverview := signalStateOutput{
- URL: signalState.GetURL(),
- Connected: signalState.GetConnected(),
- Error: signalState.Error,
- }
-
- relayOverview := mapRelays(pbFullStatus.GetRelays())
- peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
-
- overview := statusOutputOverview{
- Peers: peersOverview,
- CliVersion: version.NetbirdVersion(),
- DaemonVersion: resp.GetDaemonVersion(),
- ManagementState: managementOverview,
- SignalState: signalOverview,
- Relays: relayOverview,
- IP: pbFullStatus.GetLocalPeerState().GetIP(),
- PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
- KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
- FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
- RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
- RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
-
- Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
- NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
- NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
- Events: mapEvents(pbFullStatus.GetEvents()),
- }
-
- if anonymizeFlag {
- anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
- anonymizeOverview(anonymizer, &overview)
- }
-
- return overview
-}
-
-func mapRelays(relays []*proto.RelayState) relayStateOutput {
- var relayStateDetail []relayStateOutputDetail
-
- var relaysAvailable int
- for _, relay := range relays {
- available := relay.GetAvailable()
- relayStateDetail = append(relayStateDetail,
- relayStateOutputDetail{
- URI: relay.URI,
- Available: available,
- Error: relay.GetError(),
- },
- )
-
- if available {
- relaysAvailable++
- }
- }
-
- return relayStateOutput{
- Total: len(relays),
- Available: relaysAvailable,
- Details: relayStateDetail,
- }
-}
-
-func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
- mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
- for _, pbNsGroupServer := range servers {
- mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
- Servers: pbNsGroupServer.GetServers(),
- Domains: pbNsGroupServer.GetDomains(),
- Enabled: pbNsGroupServer.GetEnabled(),
- Error: pbNsGroupServer.GetError(),
- })
- }
- return mappedNSGroups
-}
-
-func mapPeers(peers []*proto.PeerState) peersStateOutput {
- var peersStateDetail []peerStateDetailOutput
- peersConnected := 0
- for _, pbPeerState := range peers {
- localICE := ""
- remoteICE := ""
- localICEEndpoint := ""
- remoteICEEndpoint := ""
- relayServerAddress := ""
- connType := ""
- lastHandshake := time.Time{}
- transferReceived := int64(0)
- transferSent := int64(0)
-
- isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
- if skipDetailByFilters(pbPeerState, isPeerConnected) {
- continue
- }
- if isPeerConnected {
- peersConnected++
-
- localICE = pbPeerState.GetLocalIceCandidateType()
- remoteICE = pbPeerState.GetRemoteIceCandidateType()
- localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
- remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
- connType = "P2P"
- if pbPeerState.Relayed {
- connType = "Relayed"
- }
- relayServerAddress = pbPeerState.GetRelayAddress()
- lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
- transferReceived = pbPeerState.GetBytesRx()
- transferSent = pbPeerState.GetBytesTx()
- }
-
- timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
- peerState := peerStateDetailOutput{
- IP: pbPeerState.GetIP(),
- PubKey: pbPeerState.GetPubKey(),
- Status: pbPeerState.GetConnStatus(),
- LastStatusUpdate: timeLocal,
- ConnType: connType,
- IceCandidateType: iceCandidateType{
- Local: localICE,
- Remote: remoteICE,
- },
- IceCandidateEndpoint: iceCandidateType{
- Local: localICEEndpoint,
- Remote: remoteICEEndpoint,
- },
- RelayAddress: relayServerAddress,
- FQDN: pbPeerState.GetFqdn(),
- LastWireguardHandshake: lastHandshake,
- TransferReceived: transferReceived,
- TransferSent: transferSent,
- Latency: pbPeerState.GetLatency().AsDuration(),
- RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
- Networks: pbPeerState.GetNetworks(),
- }
-
- peersStateDetail = append(peersStateDetail, peerState)
- }
-
- sortPeersByIP(peersStateDetail)
-
- peersOverview := peersStateOutput{
- Total: len(peersStateDetail),
- Connected: peersConnected,
- Details: peersStateDetail,
- }
- return peersOverview
-}
-
-func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
- if len(peersStateDetail) > 0 {
- sort.SliceStable(peersStateDetail, func(i, j int) bool {
- iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
- jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
- return iAddr.Compare(jAddr) == -1
- })
- }
-}
-
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
@@ -429,451 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
}
return fmt.Sprintf("%s\n", ip)
}
-
-func parseToJSON(overview statusOutputOverview) (string, error) {
- jsonBytes, err := json.Marshal(overview)
- if err != nil {
- return "", fmt.Errorf("json marshal failed")
- }
- return string(jsonBytes), err
-}
-
-func parseToYAML(overview statusOutputOverview) (string, error) {
- yamlBytes, err := yaml.Marshal(overview)
- if err != nil {
- return "", fmt.Errorf("yaml marshal failed")
- }
- return string(yamlBytes), nil
-}
-
-func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
- var managementConnString string
- if overview.ManagementState.Connected {
- managementConnString = "Connected"
- if showURL {
- managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
- }
- } else {
- managementConnString = "Disconnected"
- if overview.ManagementState.Error != "" {
- managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
- }
- }
-
- var signalConnString string
- if overview.SignalState.Connected {
- signalConnString = "Connected"
- if showURL {
- signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
- }
- } else {
- signalConnString = "Disconnected"
- if overview.SignalState.Error != "" {
- signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
- }
- }
-
- interfaceTypeString := "Userspace"
- interfaceIP := overview.IP
- if overview.KernelInterface {
- interfaceTypeString = "Kernel"
- } else if overview.IP == "" {
- interfaceTypeString = "N/A"
- interfaceIP = "N/A"
- }
-
- var relaysString string
- if showRelays {
- for _, relay := range overview.Relays.Details {
- available := "Available"
- reason := ""
- if !relay.Available {
- available = "Unavailable"
- reason = fmt.Sprintf(", reason: %s", relay.Error)
- }
- relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
- }
- } else {
- relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
- }
-
- networks := "-"
- if len(overview.Networks) > 0 {
- sort.Strings(overview.Networks)
- networks = strings.Join(overview.Networks, ", ")
- }
-
- var dnsServersString string
- if showNameServers {
- for _, nsServerGroup := range overview.NSServerGroups {
- enabled := "Available"
- if !nsServerGroup.Enabled {
- enabled = "Unavailable"
- }
- errorString := ""
- if nsServerGroup.Error != "" {
- errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
- errorString = strings.TrimSpace(errorString)
- }
-
- domainsString := strings.Join(nsServerGroup.Domains, ", ")
- if domainsString == "" {
- domainsString = "." // Show "." for the default zone
- }
- dnsServersString += fmt.Sprintf(
- "\n [%s] for [%s] is %s%s",
- strings.Join(nsServerGroup.Servers, ", "),
- domainsString,
- enabled,
- errorString,
- )
- }
- } else {
- dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
- }
-
- rosenpassEnabledStatus := "false"
- if overview.RosenpassEnabled {
- rosenpassEnabledStatus = "true"
- if overview.RosenpassPermissive {
- rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
- }
- }
-
- peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
-
- goos := runtime.GOOS
- goarch := runtime.GOARCH
- goarm := ""
- if goarch == "arm" {
- goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
- }
-
- summary := fmt.Sprintf(
- "OS: %s\n"+
- "Daemon version: %s\n"+
- "CLI version: %s\n"+
- "Management: %s\n"+
- "Signal: %s\n"+
- "Relays: %s\n"+
- "Nameservers: %s\n"+
- "FQDN: %s\n"+
- "NetBird IP: %s\n"+
- "Interface type: %s\n"+
- "Quantum resistance: %s\n"+
- "Networks: %s\n"+
- "Forwarding rules: %d\n"+
- "Peers count: %s\n",
- fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
- overview.DaemonVersion,
- version.NetbirdVersion(),
- managementConnString,
- signalConnString,
- relaysString,
- dnsServersString,
- overview.FQDN,
- interfaceIP,
- interfaceTypeString,
- rosenpassEnabledStatus,
- networks,
- overview.NumberOfForwardingRules,
- peersCountString,
- )
- return summary
-}
-
-func parseToFullDetailSummary(overview statusOutputOverview) string {
- parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
- parsedEventsString := parseEvents(overview.Events)
- summary := parseGeneralSummary(overview, true, true, true)
-
- return fmt.Sprintf(
- "Peers detail:"+
- "%s\n"+
- "Events:"+
- "%s\n"+
- "%s",
- parsedPeersString,
- parsedEventsString,
- summary,
- )
-}
-
-func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
- var (
- peersString = ""
- )
-
- for _, peerState := range peers.Details {
-
- localICE := "-"
- if peerState.IceCandidateType.Local != "" {
- localICE = peerState.IceCandidateType.Local
- }
-
- remoteICE := "-"
- if peerState.IceCandidateType.Remote != "" {
- remoteICE = peerState.IceCandidateType.Remote
- }
-
- localICEEndpoint := "-"
- if peerState.IceCandidateEndpoint.Local != "" {
- localICEEndpoint = peerState.IceCandidateEndpoint.Local
- }
-
- remoteICEEndpoint := "-"
- if peerState.IceCandidateEndpoint.Remote != "" {
- remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
- }
-
- rosenpassEnabledStatus := "false"
- if rosenpassEnabled {
- if peerState.RosenpassEnabled {
- rosenpassEnabledStatus = "true"
- } else {
- if rosenpassPermissive {
- rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
- } else {
- rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
- }
- }
- } else {
- if peerState.RosenpassEnabled {
- rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
- }
- }
-
- networks := "-"
- if len(peerState.Networks) > 0 {
- sort.Strings(peerState.Networks)
- networks = strings.Join(peerState.Networks, ", ")
- }
-
- peerString := fmt.Sprintf(
- "\n %s:\n"+
- " NetBird IP: %s\n"+
- " Public key: %s\n"+
- " Status: %s\n"+
- " -- detail --\n"+
- " Connection type: %s\n"+
- " ICE candidate (Local/Remote): %s/%s\n"+
- " ICE candidate endpoints (Local/Remote): %s/%s\n"+
- " Relay server address: %s\n"+
- " Last connection update: %s\n"+
- " Last WireGuard handshake: %s\n"+
- " Transfer status (received/sent) %s/%s\n"+
- " Quantum resistance: %s\n"+
- " Networks: %s\n"+
- " Latency: %s\n",
- peerState.FQDN,
- peerState.IP,
- peerState.PubKey,
- peerState.Status,
- peerState.ConnType,
- localICE,
- remoteICE,
- localICEEndpoint,
- remoteICEEndpoint,
- peerState.RelayAddress,
- timeAgo(peerState.LastStatusUpdate),
- timeAgo(peerState.LastWireguardHandshake),
- toIEC(peerState.TransferReceived),
- toIEC(peerState.TransferSent),
- rosenpassEnabledStatus,
- networks,
- peerState.Latency.String(),
- )
-
- peersString += peerString
- }
- return peersString
-}
-
-func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
- statusEval := false
- ipEval := false
- nameEval := true
-
- if statusFilter != "" {
- lowerStatusFilter := strings.ToLower(statusFilter)
- if lowerStatusFilter == "disconnected" && isConnected {
- statusEval = true
- } else if lowerStatusFilter == "connected" && !isConnected {
- statusEval = true
- }
- }
-
- if len(ipsFilter) > 0 {
- _, ok := ipsFilterMap[peerState.IP]
- if !ok {
- ipEval = true
- }
- }
-
- if len(prefixNamesFilter) > 0 {
- for prefixNameFilter := range prefixNamesFilterMap {
- if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
- nameEval = false
- break
- }
- }
- } else {
- nameEval = false
- }
-
- return statusEval || ipEval || nameEval
-}
-
-func toIEC(b int64) string {
- const unit = 1024
- if b < unit {
- return fmt.Sprintf("%d B", b)
- }
- div, exp := int64(unit), 0
- for n := b / unit; n >= unit; n /= unit {
- div *= unit
- exp++
- }
- return fmt.Sprintf("%.1f %ciB",
- float64(b)/float64(div), "KMGTPE"[exp])
-}
-
-func countEnabled(dnsServers []nsServerGroupStateOutput) int {
- count := 0
- for _, server := range dnsServers {
- if server.Enabled {
- count++
- }
- }
- return count
-}
-
-// timeAgo returns a string representing the duration since the provided time in a human-readable format.
-func timeAgo(t time.Time) string {
- if t.IsZero() || t.Equal(time.Unix(0, 0)) {
- return "-"
- }
- duration := time.Since(t)
- switch {
- case duration < time.Second:
- return "Now"
- case duration < time.Minute:
- seconds := int(duration.Seconds())
- if seconds == 1 {
- return "1 second ago"
- }
- return fmt.Sprintf("%d seconds ago", seconds)
- case duration < time.Hour:
- minutes := int(duration.Minutes())
- seconds := int(duration.Seconds()) % 60
- if minutes == 1 {
- if seconds == 1 {
- return "1 minute, 1 second ago"
- } else if seconds > 0 {
- return fmt.Sprintf("1 minute, %d seconds ago", seconds)
- }
- return "1 minute ago"
- }
- if seconds > 0 {
- return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
- }
- return fmt.Sprintf("%d minutes ago", minutes)
- case duration < 24*time.Hour:
- hours := int(duration.Hours())
- minutes := int(duration.Minutes()) % 60
- if hours == 1 {
- if minutes == 1 {
- return "1 hour, 1 minute ago"
- } else if minutes > 0 {
- return fmt.Sprintf("1 hour, %d minutes ago", minutes)
- }
- return "1 hour ago"
- }
- if minutes > 0 {
- return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
- }
- return fmt.Sprintf("%d hours ago", hours)
- }
-
- days := int(duration.Hours()) / 24
- hours := int(duration.Hours()) % 24
- if days == 1 {
- if hours == 1 {
- return "1 day, 1 hour ago"
- } else if hours > 0 {
- return fmt.Sprintf("1 day, %d hours ago", hours)
- }
- return "1 day ago"
- }
- if hours > 0 {
- return fmt.Sprintf("%d days, %d hours ago", days, hours)
- }
- return fmt.Sprintf("%d days ago", days)
-}
-
-func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
- peer.FQDN = a.AnonymizeDomain(peer.FQDN)
- if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
- peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
- }
- if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
- peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
- }
-
- peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
-
- for i, route := range peer.Networks {
- peer.Networks[i] = a.AnonymizeIPString(route)
- }
-
- for i, route := range peer.Networks {
- peer.Networks[i] = a.AnonymizeRoute(route)
- }
-}
-
-func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
- for i, peer := range overview.Peers.Details {
- peer := peer
- anonymizePeerDetail(a, &peer)
- overview.Peers.Details[i] = peer
- }
-
- overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
- overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
- overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
- overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
-
- overview.IP = a.AnonymizeIPString(overview.IP)
- for i, detail := range overview.Relays.Details {
- detail.URI = a.AnonymizeURI(detail.URI)
- detail.Error = a.AnonymizeString(detail.Error)
- overview.Relays.Details[i] = detail
- }
-
- for i, nsGroup := range overview.NSServerGroups {
- for j, domain := range nsGroup.Domains {
- overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
- }
- for j, ns := range nsGroup.Servers {
- host, port, err := net.SplitHostPort(ns)
- if err == nil {
- overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
- }
- }
- }
-
- for i, route := range overview.Networks {
- overview.Networks[i] = a.AnonymizeRoute(route)
- }
-
- overview.FQDN = a.AnonymizeDomain(overview.FQDN)
-
- for i, event := range overview.Events {
- overview.Events[i].Message = a.AnonymizeString(event.Message)
- overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
-
- for k, v := range event.Metadata {
- event.Metadata[k] = a.AnonymizeString(v)
- }
- }
-}
diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go
index 0b0ae4c51..03608eab0 100644
--- a/client/cmd/status_test.go
+++ b/client/cmd/status_test.go
@@ -1,583 +1,11 @@
package cmd
import (
- "bytes"
- "encoding/json"
- "fmt"
- "runtime"
"testing"
- "time"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "google.golang.org/protobuf/types/known/durationpb"
- "google.golang.org/protobuf/types/known/timestamppb"
-
- "github.com/netbirdio/netbird/client/proto"
- "github.com/netbirdio/netbird/version"
)
-func init() {
- loc, err := time.LoadLocation("UTC")
- if err != nil {
- panic(err)
- }
-
- time.Local = loc
-}
-
-var resp = &proto.StatusResponse{
- Status: "Connected",
- FullStatus: &proto.FullStatus{
- Peers: []*proto.PeerState{
- {
- IP: "192.168.178.101",
- PubKey: "Pubkey1",
- Fqdn: "peer-1.awesome-domain.com",
- ConnStatus: "Connected",
- ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
- Relayed: false,
- LocalIceCandidateType: "",
- RemoteIceCandidateType: "",
- LocalIceCandidateEndpoint: "",
- RemoteIceCandidateEndpoint: "",
- LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
- BytesRx: 200,
- BytesTx: 100,
- Networks: []string{
- "10.1.0.0/24",
- },
- Latency: durationpb.New(time.Duration(10000000)),
- },
- {
- IP: "192.168.178.102",
- PubKey: "Pubkey2",
- Fqdn: "peer-2.awesome-domain.com",
- ConnStatus: "Connected",
- ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
- Relayed: true,
- LocalIceCandidateType: "relay",
- RemoteIceCandidateType: "prflx",
- LocalIceCandidateEndpoint: "10.0.0.1:10001",
- RemoteIceCandidateEndpoint: "10.0.10.1:10002",
- LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
- BytesRx: 2000,
- BytesTx: 1000,
- Latency: durationpb.New(time.Duration(10000000)),
- },
- },
- ManagementState: &proto.ManagementState{
- URL: "my-awesome-management.com:443",
- Connected: true,
- Error: "",
- },
- SignalState: &proto.SignalState{
- URL: "my-awesome-signal.com:443",
- Connected: true,
- Error: "",
- },
- Relays: []*proto.RelayState{
- {
- URI: "stun:my-awesome-stun.com:3478",
- Available: true,
- Error: "",
- },
- {
- URI: "turns:my-awesome-turn.com:443?transport=tcp",
- Available: false,
- Error: "context: deadline exceeded",
- },
- },
- LocalPeerState: &proto.LocalPeerState{
- IP: "192.168.178.100/16",
- PubKey: "Some-Pub-Key",
- KernelInterface: true,
- Fqdn: "some-localhost.awesome-domain.com",
- Networks: []string{
- "10.10.0.0/24",
- },
- },
- DnsServers: []*proto.NSGroupState{
- {
- Servers: []string{
- "8.8.8.8:53",
- },
- Domains: nil,
- Enabled: true,
- Error: "",
- },
- {
- Servers: []string{
- "1.1.1.1:53",
- "2.2.2.2:53",
- },
- Domains: []string{
- "example.com",
- "example.net",
- },
- Enabled: false,
- Error: "timeout",
- },
- },
- },
- DaemonVersion: "0.14.1",
-}
-
-var overview = statusOutputOverview{
- Peers: peersStateOutput{
- Total: 2,
- Connected: 2,
- Details: []peerStateDetailOutput{
- {
- IP: "192.168.178.101",
- PubKey: "Pubkey1",
- FQDN: "peer-1.awesome-domain.com",
- Status: "Connected",
- LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
- ConnType: "P2P",
- IceCandidateType: iceCandidateType{
- Local: "",
- Remote: "",
- },
- IceCandidateEndpoint: iceCandidateType{
- Local: "",
- Remote: "",
- },
- LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
- TransferReceived: 200,
- TransferSent: 100,
- Networks: []string{
- "10.1.0.0/24",
- },
- Latency: time.Duration(10000000),
- },
- {
- IP: "192.168.178.102",
- PubKey: "Pubkey2",
- FQDN: "peer-2.awesome-domain.com",
- Status: "Connected",
- LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
- ConnType: "Relayed",
- IceCandidateType: iceCandidateType{
- Local: "relay",
- Remote: "prflx",
- },
- IceCandidateEndpoint: iceCandidateType{
- Local: "10.0.0.1:10001",
- Remote: "10.0.10.1:10002",
- },
- LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
- TransferReceived: 2000,
- TransferSent: 1000,
- Latency: time.Duration(10000000),
- },
- },
- },
- Events: []systemEventOutput{},
- CliVersion: version.NetbirdVersion(),
- DaemonVersion: "0.14.1",
- ManagementState: managementStateOutput{
- URL: "my-awesome-management.com:443",
- Connected: true,
- Error: "",
- },
- SignalState: signalStateOutput{
- URL: "my-awesome-signal.com:443",
- Connected: true,
- Error: "",
- },
- Relays: relayStateOutput{
- Total: 2,
- Available: 1,
- Details: []relayStateOutputDetail{
- {
- URI: "stun:my-awesome-stun.com:3478",
- Available: true,
- Error: "",
- },
- {
- URI: "turns:my-awesome-turn.com:443?transport=tcp",
- Available: false,
- Error: "context: deadline exceeded",
- },
- },
- },
- IP: "192.168.178.100/16",
- PubKey: "Some-Pub-Key",
- KernelInterface: true,
- FQDN: "some-localhost.awesome-domain.com",
- NSServerGroups: []nsServerGroupStateOutput{
- {
- Servers: []string{
- "8.8.8.8:53",
- },
- Domains: nil,
- Enabled: true,
- Error: "",
- },
- {
- Servers: []string{
- "1.1.1.1:53",
- "2.2.2.2:53",
- },
- Domains: []string{
- "example.com",
- "example.net",
- },
- Enabled: false,
- Error: "timeout",
- },
- },
- Networks: []string{
- "10.10.0.0/24",
- },
-}
-
-func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
- convertedResult := convertToStatusOutputOverview(resp)
-
- assert.Equal(t, overview, convertedResult)
-}
-
-func TestSortingOfPeers(t *testing.T) {
- peers := []peerStateDetailOutput{
- {
- IP: "192.168.178.104",
- },
- {
- IP: "192.168.178.102",
- },
- {
- IP: "192.168.178.101",
- },
- {
- IP: "192.168.178.105",
- },
- {
- IP: "192.168.178.103",
- },
- }
-
- sortPeersByIP(peers)
-
- assert.Equal(t, peers[3].IP, "192.168.178.104")
-}
-
-func TestParsingToJSON(t *testing.T) {
- jsonString, _ := parseToJSON(overview)
-
- //@formatter:off
- expectedJSONString := `
- {
- "peers": {
- "total": 2,
- "connected": 2,
- "details": [
- {
- "fqdn": "peer-1.awesome-domain.com",
- "netbirdIp": "192.168.178.101",
- "publicKey": "Pubkey1",
- "status": "Connected",
- "lastStatusUpdate": "2001-01-01T01:01:01Z",
- "connectionType": "P2P",
- "iceCandidateType": {
- "local": "",
- "remote": ""
- },
- "iceCandidateEndpoint": {
- "local": "",
- "remote": ""
- },
- "relayAddress": "",
- "lastWireguardHandshake": "2001-01-01T01:01:02Z",
- "transferReceived": 200,
- "transferSent": 100,
- "latency": 10000000,
- "quantumResistance": false,
- "networks": [
- "10.1.0.0/24"
- ]
- },
- {
- "fqdn": "peer-2.awesome-domain.com",
- "netbirdIp": "192.168.178.102",
- "publicKey": "Pubkey2",
- "status": "Connected",
- "lastStatusUpdate": "2002-02-02T02:02:02Z",
- "connectionType": "Relayed",
- "iceCandidateType": {
- "local": "relay",
- "remote": "prflx"
- },
- "iceCandidateEndpoint": {
- "local": "10.0.0.1:10001",
- "remote": "10.0.10.1:10002"
- },
- "relayAddress": "",
- "lastWireguardHandshake": "2002-02-02T02:02:03Z",
- "transferReceived": 2000,
- "transferSent": 1000,
- "latency": 10000000,
- "quantumResistance": false,
- "networks": null
- }
- ]
- },
- "cliVersion": "development",
- "daemonVersion": "0.14.1",
- "management": {
- "url": "my-awesome-management.com:443",
- "connected": true,
- "error": ""
- },
- "signal": {
- "url": "my-awesome-signal.com:443",
- "connected": true,
- "error": ""
- },
- "relays": {
- "total": 2,
- "available": 1,
- "details": [
- {
- "uri": "stun:my-awesome-stun.com:3478",
- "available": true,
- "error": ""
- },
- {
- "uri": "turns:my-awesome-turn.com:443?transport=tcp",
- "available": false,
- "error": "context: deadline exceeded"
- }
- ]
- },
- "netbirdIp": "192.168.178.100/16",
- "publicKey": "Some-Pub-Key",
- "usesKernelInterface": true,
- "fqdn": "some-localhost.awesome-domain.com",
- "quantumResistance": false,
- "quantumResistancePermissive": false,
- "networks": [
- "10.10.0.0/24"
- ],
- "forwardingRules": 0,
- "dnsServers": [
- {
- "servers": [
- "8.8.8.8:53"
- ],
- "domains": null,
- "enabled": true,
- "error": ""
- },
- {
- "servers": [
- "1.1.1.1:53",
- "2.2.2.2:53"
- ],
- "domains": [
- "example.com",
- "example.net"
- ],
- "enabled": false,
- "error": "timeout"
- }
- ],
- "events": []
- }`
- // @formatter:on
-
- var expectedJSON bytes.Buffer
- require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
-
- assert.Equal(t, expectedJSON.String(), jsonString)
-}
-
-func TestParsingToYAML(t *testing.T) {
- yaml, _ := parseToYAML(overview)
-
- expectedYAML :=
- `peers:
- total: 2
- connected: 2
- details:
- - fqdn: peer-1.awesome-domain.com
- netbirdIp: 192.168.178.101
- publicKey: Pubkey1
- status: Connected
- lastStatusUpdate: 2001-01-01T01:01:01Z
- connectionType: P2P
- iceCandidateType:
- local: ""
- remote: ""
- iceCandidateEndpoint:
- local: ""
- remote: ""
- relayAddress: ""
- lastWireguardHandshake: 2001-01-01T01:01:02Z
- transferReceived: 200
- transferSent: 100
- latency: 10ms
- quantumResistance: false
- networks:
- - 10.1.0.0/24
- - fqdn: peer-2.awesome-domain.com
- netbirdIp: 192.168.178.102
- publicKey: Pubkey2
- status: Connected
- lastStatusUpdate: 2002-02-02T02:02:02Z
- connectionType: Relayed
- iceCandidateType:
- local: relay
- remote: prflx
- iceCandidateEndpoint:
- local: 10.0.0.1:10001
- remote: 10.0.10.1:10002
- relayAddress: ""
- lastWireguardHandshake: 2002-02-02T02:02:03Z
- transferReceived: 2000
- transferSent: 1000
- latency: 10ms
- quantumResistance: false
- networks: []
-cliVersion: development
-daemonVersion: 0.14.1
-management:
- url: my-awesome-management.com:443
- connected: true
- error: ""
-signal:
- url: my-awesome-signal.com:443
- connected: true
- error: ""
-relays:
- total: 2
- available: 1
- details:
- - uri: stun:my-awesome-stun.com:3478
- available: true
- error: ""
- - uri: turns:my-awesome-turn.com:443?transport=tcp
- available: false
- error: 'context: deadline exceeded'
-netbirdIp: 192.168.178.100/16
-publicKey: Some-Pub-Key
-usesKernelInterface: true
-fqdn: some-localhost.awesome-domain.com
-quantumResistance: false
-quantumResistancePermissive: false
-networks:
- - 10.10.0.0/24
-forwardingRules: 0
-dnsServers:
- - servers:
- - 8.8.8.8:53
- domains: []
- enabled: true
- error: ""
- - servers:
- - 1.1.1.1:53
- - 2.2.2.2:53
- domains:
- - example.com
- - example.net
- enabled: false
- error: timeout
-events: []
-`
-
- assert.Equal(t, expectedYAML, yaml)
-}
-
-func TestParsingToDetail(t *testing.T) {
- // Calculate time ago based on the fixture dates
- lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
- lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
- lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
- lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
-
- detail := parseToFullDetailSummary(overview)
-
- expectedDetail := fmt.Sprintf(
- `Peers detail:
- peer-1.awesome-domain.com:
- NetBird IP: 192.168.178.101
- Public key: Pubkey1
- Status: Connected
- -- detail --
- Connection type: P2P
- ICE candidate (Local/Remote): -/-
- ICE candidate endpoints (Local/Remote): -/-
- Relay server address:
- Last connection update: %s
- Last WireGuard handshake: %s
- Transfer status (received/sent) 200 B/100 B
- Quantum resistance: false
- Networks: 10.1.0.0/24
- Latency: 10ms
-
- peer-2.awesome-domain.com:
- NetBird IP: 192.168.178.102
- Public key: Pubkey2
- Status: Connected
- -- detail --
- Connection type: Relayed
- ICE candidate (Local/Remote): relay/prflx
- ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
- Relay server address:
- Last connection update: %s
- Last WireGuard handshake: %s
- Transfer status (received/sent) 2.0 KiB/1000 B
- Quantum resistance: false
- Networks: -
- Latency: 10ms
-
-Events: No events recorded
-OS: %s/%s
-Daemon version: 0.14.1
-CLI version: %s
-Management: Connected to my-awesome-management.com:443
-Signal: Connected to my-awesome-signal.com:443
-Relays:
- [stun:my-awesome-stun.com:3478] is Available
- [turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
-Nameservers:
- [8.8.8.8:53] for [.] is Available
- [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
-FQDN: some-localhost.awesome-domain.com
-NetBird IP: 192.168.178.100/16
-Interface type: Kernel
-Quantum resistance: false
-Networks: 10.10.0.0/24
-Forwarding rules: 0
-Peers count: 2/2 Connected
-`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
-
- assert.Equal(t, expectedDetail, detail)
-}
-
-func TestParsingToShortVersion(t *testing.T) {
- shortVersion := parseGeneralSummary(overview, false, false, false)
-
- expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
-Daemon version: 0.14.1
-CLI version: development
-Management: Connected
-Signal: Connected
-Relays: 1/2 Available
-Nameservers: 1/2 Available
-FQDN: some-localhost.awesome-domain.com
-NetBird IP: 192.168.178.100/16
-Interface type: Kernel
-Quantum resistance: false
-Networks: 10.10.0.0/24
-Forwarding rules: 0
-Peers count: 2/2 Connected
-`
-
- assert.Equal(t, expectedString, shortVersion)
-}
-
func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16"
@@ -585,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
-
-func TestTimeAgo(t *testing.T) {
- now := time.Now()
-
- cases := []struct {
- name string
- input time.Time
- expected string
- }{
- {"Now", now, "Now"},
- {"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
- {"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
- {"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
- {"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
- {"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
- {"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
- {"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
- {"Zero time", time.Time{}, "-"},
- {"Unix zero time", time.Unix(0, 0), "-"},
- }
-
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- result := timeAgo(tc.input)
- assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
- })
- }
-}
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index ee67d2501..4c06a7da0 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -96,7 +96,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
- mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
+ mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go
index 97e4662fd..c37740587 100644
--- a/client/firewall/uspfilter/forwarder/udp.go
+++ b/client/firewall/uspfilter/forwarder/udp.go
@@ -245,33 +245,29 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
defer bufPool.Put(bufp)
buffer := *bufp
- if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
- return fmt.Errorf("set read deadline: %w", err)
- }
- if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
- return fmt.Errorf("set write deadline: %w", err)
- }
-
for {
- select {
- case <-ctx.Done():
+ if ctx.Err() != nil {
return ctx.Err()
- default:
- n, err := src.Read(buffer)
- if err != nil {
- if isTimeout(err) {
- continue
- }
- return fmt.Errorf("read from %s: %w", direction, err)
- }
-
- _, err = dst.Write(buffer[:n])
- if err != nil {
- return fmt.Errorf("write to %s: %w", direction, err)
- }
-
- c.updateLastSeen()
}
+
+ if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
+ return fmt.Errorf("set read deadline: %w", err)
+ }
+
+ n, err := src.Read(buffer)
+ if err != nil {
+ if isTimeout(err) {
+ continue
+ }
+ return fmt.Errorf("read from %s: %w", direction, err)
+ }
+
+ _, err = dst.Write(buffer[:n])
+ if err != nil {
+ return fmt.Errorf("write to %s: %w", direction, err)
+ }
+
+ c.updateLastSeen()
}
}
diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go
index 00a91f0ec..4c827de95 100644
--- a/client/iface/bind/udp_mux.go
+++ b/client/iface/bind/udp_mux.go
@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
+ "slices"
"strings"
"sync"
@@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
- var localAddrsForUnspecified []net.Addr
- if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
- params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
- } else if ok && addr.IP.IsUnspecified() {
- // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
- // it will break the applications that are already using unspecified UDP connection
- // with UDPMuxDefault, so print a warn log and create a local address list for mux.
- params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
- var networks []ice.NetworkType
- switch {
-
- case addr.IP.To16() != nil:
- networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
-
- case addr.IP.To4() != nil:
- networks = []ice.NetworkType{ice.NetworkTypeUDP4}
-
- default:
- params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
- }
- if len(networks) > 0 {
- if params.Net == nil {
- var err error
- if params.Net, err = stdnet.NewNet(); err != nil {
- params.Logger.Errorf("failed to get create network: %v", err)
- }
- }
-
- ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
- if err == nil {
- for _, ip := range ips {
- localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
- }
- } else {
- params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
- }
- }
- }
-
- return &UDPMuxDefault{
+ mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
@@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
- localAddrsForUnspecified: localAddrsForUnspecified,
}
+
+ mux.updateLocalAddresses()
+ return mux
+}
+
+func (m *UDPMuxDefault) updateLocalAddresses() {
+ var localAddrsForUnspecified []net.Addr
+ if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
+ m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
+ } else if ok && addr.IP.IsUnspecified() {
+ // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
+ // it will break the applications that are already using unspecified UDP connection
+ // with UDPMuxDefault, so print a warn log and create a local address list for mux.
+ m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
+ var networks []ice.NetworkType
+ switch {
+
+ case addr.IP.To16() != nil:
+ networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
+
+ case addr.IP.To4() != nil:
+ networks = []ice.NetworkType{ice.NetworkTypeUDP4}
+
+ default:
+ m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
+ }
+ if len(networks) > 0 {
+ if m.params.Net == nil {
+ var err error
+ if m.params.Net, err = stdnet.NewNet(); err != nil {
+ m.params.Logger.Errorf("failed to get create network: %v", err)
+ }
+ }
+
+ ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
+ if err == nil {
+ for _, ip := range ips {
+ localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
+ }
+ } else {
+ m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
+ }
+ }
+ }
+
+ m.mu.Lock()
+ m.localAddrsForUnspecified = localAddrsForUnspecified
+ m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
@@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
+ m.updateLocalAddresses()
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 {
- return m.localAddrsForUnspecified
+ return slices.Clone(m.localAddrsForUnspecified)
}
return []net.Addr{m.LocalAddr()}
@@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
- if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
+ m.mu.Lock()
+ lenLocalAddrs := len(m.localAddrsForUnspecified)
+ m.mu.Unlock()
+ if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String())
}
diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go
index 7c1c41669..6f09a63c9 100644
--- a/client/iface/configurer/kernel_unix.go
+++ b/client/iface/configurer/kernel_unix.go
@@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
-func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
- // parse allowed ips
- _, ipNet, err := net.ParseCIDR(allowedIps)
- if err != nil {
- return err
- }
-
+func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
- AllowedIPs: []net.IPNet{*ipNet},
+ AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,
diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go
index 391269dd0..a3de58c24 100644
--- a/client/iface/configurer/usp.go
+++ b/client/iface/configurer/usp.go
@@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
- // parse allowed ips
- _, ipNet, err := net.ParseCIDR(allowedIps)
- if err != nil {
- return err
- }
-
+func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
- AllowedIPs: []net.IPNet{*ipNet},
+ AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go
index 0196b0085..6971b6946 100644
--- a/client/iface/device/interface.go
+++ b/client/iface/device/interface.go
@@ -11,7 +11,7 @@ import (
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
- UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
diff --git a/client/iface/iface.go b/client/iface/iface.go
index 8056dd9a6..40bd51fbb 100644
--- a/client/iface/iface.go
+++ b/client/iface/iface.go
@@ -3,6 +3,7 @@ package iface
import (
"fmt"
"net"
+ "net/netip"
"sync"
"time"
@@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
-func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
+ netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
- return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+ return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet()
}
+
+func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
+ ipNets := make([]net.IPNet, len(prefixes))
+ for i, prefix := range prefixes {
+ ipNets[i] = net.IPNet{
+ IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
+ Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
+ }
+ }
+ return ipNets
+}
diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go
deleted file mode 100644
index f92a8cfc8..000000000
--- a/client/iface/iface_moc.go
+++ /dev/null
@@ -1,123 +0,0 @@
-package iface
-
-import (
- "net"
- "time"
-
- wgdevice "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
-
- "github.com/netbirdio/netbird/client/iface/bind"
- "github.com/netbirdio/netbird/client/iface/configurer"
- "github.com/netbirdio/netbird/client/iface/device"
- "github.com/netbirdio/netbird/client/iface/wgproxy"
-)
-
-type MockWGIface struct {
- CreateFunc func() error
- CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
- IsUserspaceBindFunc func() bool
- NameFunc func() string
- AddressFunc func() device.WGAddress
- ToInterfaceFunc func() *net.Interface
- UpFunc func() (*bind.UniversalUDPMuxDefault, error)
- UpdateAddrFunc func(newAddr string) error
- UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
- RemovePeerFunc func(peerKey string) error
- AddAllowedIPFunc func(peerKey string, allowedIP string) error
- RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
- CloseFunc func() error
- SetFilterFunc func(filter device.PacketFilter) error
- GetFilterFunc func() device.PacketFilter
- GetDeviceFunc func() *device.FilteredDevice
- GetWGDeviceFunc func() *wgdevice.Device
- GetStatsFunc func(peerKey string) (configurer.WGStats, error)
- GetInterfaceGUIDStringFunc func() (string, error)
- GetProxyFunc func() wgproxy.Proxy
- GetNetFunc func() *netstack.Net
-}
-
-func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
- return m.GetInterfaceGUIDStringFunc()
-}
-
-func (m *MockWGIface) Create() error {
- return m.CreateFunc()
-}
-
-func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
- return m.CreateOnAndroidFunc(routeRange, ip, domains)
-}
-
-func (m *MockWGIface) IsUserspaceBind() bool {
- return m.IsUserspaceBindFunc()
-}
-
-func (m *MockWGIface) Name() string {
- return m.NameFunc()
-}
-
-func (m *MockWGIface) Address() device.WGAddress {
- return m.AddressFunc()
-}
-
-func (m *MockWGIface) ToInterface() *net.Interface {
- return m.ToInterfaceFunc()
-}
-
-func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
- return m.UpFunc()
-}
-
-func (m *MockWGIface) UpdateAddr(newAddr string) error {
- return m.UpdateAddrFunc(newAddr)
-}
-
-func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
- return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
-}
-
-func (m *MockWGIface) RemovePeer(peerKey string) error {
- return m.RemovePeerFunc(peerKey)
-}
-
-func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
- return m.AddAllowedIPFunc(peerKey, allowedIP)
-}
-
-func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
- return m.RemoveAllowedIPFunc(peerKey, allowedIP)
-}
-
-func (m *MockWGIface) Close() error {
- return m.CloseFunc()
-}
-
-func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
- return m.SetFilterFunc(filter)
-}
-
-func (m *MockWGIface) GetFilter() device.PacketFilter {
- return m.GetFilterFunc()
-}
-
-func (m *MockWGIface) GetDevice() *device.FilteredDevice {
- return m.GetDeviceFunc()
-}
-
-func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
- return m.GetWGDeviceFunc()
-}
-
-func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
- return m.GetStatsFunc(peerKey)
-}
-
-func (m *MockWGIface) GetProxy() wgproxy.Proxy {
- return m.GetProxyFunc()
-}
-
-func (m *MockWGIface) GetNet() *netstack.Net {
- return m.GetNetFunc()
-}
diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go
index 85db9cacb..e890b30f3 100644
--- a/client/iface/iface_test.go
+++ b/client/iface/iface_test.go
@@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
- allowedIP := "10.99.99.10/32"
+ allowedIP := netip.MustParsePrefix("10.99.99.10/32")
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
if err != nil {
t.Fatal(err)
}
- err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil)
+ err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil)
if err != nil {
t.Fatal(err)
}
@@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) {
var foundAllowedIP bool
for _, aip := range peer.AllowedIPs {
- if aip.String() == allowedIP {
+ if aip.String() == allowedIP.String() {
foundAllowedIP = true
break
}
@@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
- allowedIP := "10.99.99.14/32"
-
- err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
+ allowedIP := netip.MustParsePrefix("10.99.99.14/32")
+ err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) {
func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
- peer1wgIP := "10.99.99.17/30"
+ peer1wgIP := netip.MustParsePrefix("10.99.99.17/30")
peer1Key, _ := wgtypes.GeneratePrivateKey()
peer1wgPort := 33100
peer2ifaceName := "utun500"
- peer2wgIP := "10.99.99.18/30"
+ peer2wgIP := netip.MustParsePrefix("10.99.99.18/30")
peer2Key, _ := wgtypes.GeneratePrivateKey()
peer2wgPort := 33200
@@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
- Address: peer1wgIP,
+ Address: peer1wgIP.String(),
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
@@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
- Address: peer2wgIP,
+ Address: peer2wgIP.String(),
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
@@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) {
}
}()
- err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil)
+ err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil)
if err != nil {
t.Fatal(err)
}
- err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil)
+ err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil)
if err != nil {
t.Fatal(err)
}
diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go
deleted file mode 100644
index cac096b54..000000000
--- a/client/iface/iwginterface_windows.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package iface
-
-import (
- "net"
- "time"
-
- wgdevice "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
-
- "github.com/netbirdio/netbird/client/iface/bind"
- "github.com/netbirdio/netbird/client/iface/configurer"
- "github.com/netbirdio/netbird/client/iface/device"
- "github.com/netbirdio/netbird/client/iface/wgproxy"
-)
-
-type IWGIface interface {
- Create() error
- CreateOnAndroid(routeRange []string, ip string, domains []string) error
- IsUserspaceBind() bool
- Name() string
- Address() device.WGAddress
- ToInterface() *net.Interface
- Up() (*bind.UniversalUDPMuxDefault, error)
- UpdateAddr(newAddr string) error
- GetProxy() wgproxy.Proxy
- UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
- RemovePeer(peerKey string) error
- AddAllowedIP(peerKey string, allowedIP string) error
- RemoveAllowedIP(peerKey string, allowedIP string) error
- Close() error
- SetFilter(filter device.PacketFilter) error
- GetFilter() device.PacketFilter
- GetDevice() *device.FilteredDevice
- GetWGDevice() *wgdevice.Device
- GetStats(peerKey string) (configurer.WGStats, error)
- GetInterfaceGUIDString() (string, error)
- GetNet() *netstack.Net
-}
diff --git a/client/internal/config.go b/client/internal/config.go
index b269a3854..b2f96cbdc 100644
--- a/client/internal/config.go
+++ b/client/internal/config.go
@@ -99,7 +99,7 @@ type Config struct {
BlockLANAccess bool
- DisableNotifications bool
+ DisableNotifications *bool
DNSLabels domain.List
@@ -479,13 +479,20 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
- if input.DisableNotifications != nil && *input.DisableNotifications != config.DisableNotifications {
+ if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications {
log.Infof("disabling notifications")
} else {
log.Infof("enabling notifications")
}
- config.DisableNotifications = *input.DisableNotifications
+ config.DisableNotifications = input.DisableNotifications
+ updated = true
+ }
+
+ if config.DisableNotifications == nil {
+ disabled := true
+ config.DisableNotifications = &disabled
+ log.Infof("setting notifications to disabled by default")
updated = true
}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 26ae3b687..bf513ed39 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
+ cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/management/client"
@@ -104,6 +105,16 @@ func (c *ConnectClient) RunOniOS(
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
defer func() {
if r := recover(); r != nil {
+ rec := c.statusRecorder
+ if rec != nil {
+ rec.PublishEvent(
+ cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM,
+ "panic occurred",
+ "The Netbird service panicked. Please restart the service and submit a bug report with the client logs.",
+ nil,
+ )
+ }
+
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
diff --git a/client/internal/dns.go b/client/internal/dns.go
new file mode 100644
index 000000000..8a73f50f2
--- /dev/null
+++ b/client/internal/dns.go
@@ -0,0 +1,111 @@
+package internal
+
+import (
+ "fmt"
+ "net"
+ "slices"
+ "strings"
+
+ "github.com/miekg/dns"
+ log "github.com/sirupsen/logrus"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+)
+
+func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
+ ip := net.ParseIP(aRecord.RData)
+ if ip == nil || ip.To4() == nil {
+ return nbdns.SimpleRecord{}, false
+ }
+
+ if !ipNet.Contains(ip) {
+ return nbdns.SimpleRecord{}, false
+ }
+
+ ipOctets := strings.Split(ip.String(), ".")
+ slices.Reverse(ipOctets)
+ rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
+
+ return nbdns.SimpleRecord{
+ Name: rdnsName,
+ Type: int(dns.TypePTR),
+ Class: aRecord.Class,
+ TTL: aRecord.TTL,
+ RData: dns.Fqdn(aRecord.Name),
+ }, true
+}
+
+// generateReverseZoneName creates the reverse DNS zone name for a given network
+func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
+ networkIP := ipNet.IP.Mask(ipNet.Mask)
+ maskOnes, _ := ipNet.Mask.Size()
+
+ // round up to nearest byte
+ octetsToUse := (maskOnes + 7) / 8
+
+ octets := strings.Split(networkIP.String(), ".")
+ if octetsToUse > len(octets) {
+ return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
+ }
+
+ reverseOctets := make([]string, octetsToUse)
+ for i := 0; i < octetsToUse; i++ {
+ reverseOctets[octetsToUse-1-i] = octets[i]
+ }
+
+ return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
+}
+
+// zoneExists checks if a zone with the given name already exists in the configuration
+func zoneExists(config *nbdns.Config, zoneName string) bool {
+ for _, zone := range config.CustomZones {
+ if zone.Domain == zoneName {
+ log.Debugf("reverse DNS zone %s already exists", zoneName)
+ return true
+ }
+ }
+ return false
+}
+
+// collectPTRRecords gathers all PTR records for the given network from A records
+func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
+ var records []nbdns.SimpleRecord
+
+ for _, zone := range config.CustomZones {
+ for _, record := range zone.Records {
+ if record.Type != int(dns.TypeA) {
+ continue
+ }
+
+ if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
+ records = append(records, ptrRecord)
+ }
+ }
+ }
+
+ return records
+}
+
+// addReverseZone adds a reverse DNS zone to the configuration for the given network
+func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
+ zoneName, err := generateReverseZoneName(ipNet)
+ if err != nil {
+ log.Warn(err)
+ return
+ }
+
+ if zoneExists(config, zoneName) {
+ log.Debugf("reverse DNS zone %s already exists", zoneName)
+ return
+ }
+
+ records := collectPTRRecords(config, ipNet)
+
+ reverseZone := nbdns.CustomZone{
+ Domain: zoneName,
+ Records: records,
+ }
+
+ config.CustomZones = append(config.CustomZones, reverseZone)
+ log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
+}
diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go
index 02ae26e10..1f4ddb67c 100644
--- a/client/internal/dns/file_unix.go
+++ b/client/internal/dns/file_unix.go
@@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
}
}
- return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
+ return ErrRouteAllWithoutNameserverGroup
}
if !backupFileExist {
@@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error {
return f.restore()
}
+func (f *fileConfigurator) string() string {
+ return "file"
+}
+
func (f *fileConfigurator) backup() error {
stats, err := os.Stat(defaultResolvConfPath)
if err != nil {
diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go
index fbe8c4dbb..25e9ff7e5 100644
--- a/client/internal/dns/host.go
+++ b/client/internal/dns/host.go
@@ -9,10 +9,18 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
+var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
+
+const (
+ ipv4ReverseZone = ".in-addr.arpa"
+ ipv6ReverseZone = ".ip6.arpa"
+)
+
type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
supportCustomPort() bool
+ string() string
}
type SystemDNSSettings struct {
@@ -39,6 +47,7 @@ type mockHostConfigurator struct {
restoreHostDNSFunc func() error
supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
+ stringFunc func() string
}
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
@@ -62,6 +71,13 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false
}
+func (m *mockHostConfigurator) string() string {
+ if m.stringFunc != nil {
+ return m.stringFunc()
+ }
+ return "mock"
+}
+
func newNoopHostMocker() hostManager {
return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
@@ -94,9 +110,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
}
for _, customZone := range dnsConfig.CustomZones {
+ matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."),
- MatchOnly: false,
+ MatchOnly: matchOnly,
})
}
@@ -116,3 +133,7 @@ func (n noopHostConfigurator) restoreHostDNS() error {
func (n noopHostConfigurator) supportCustomPort() bool {
return true
}
+
+func (n noopHostConfigurator) string() string {
+ return "noop"
+}
diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go
index 5653710d7..dfa3e5712 100644
--- a/client/internal/dns/host_android.go
+++ b/client/internal/dns/host_android.go
@@ -22,3 +22,7 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool {
return false
}
+
+func (a androidHostManager) string() string {
+ return "none"
+}
diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go
index 2f92dd367..f727f68b5 100644
--- a/client/internal/dns/host_darwin.go
+++ b/client/internal/dns/host_darwin.go
@@ -114,6 +114,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil
}
+func (s *systemConfigurator) string() string {
+ return "scutil"
+}
+
func (s *systemConfigurator) restoreHostDNS() error {
keys := s.getRemovableKeysWithDefaults()
for _, key := range keys {
diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go
index 4a0acf572..1c0ac63e9 100644
--- a/client/internal/dns/host_ios.go
+++ b/client/internal/dns/host_ios.go
@@ -38,3 +38,7 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool {
return false
}
+
+func (a iosHostManager) string() string {
+ return "none"
+}
diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go
index 58b0a14de..dceb24420 100644
--- a/client/internal/dns/host_windows.go
+++ b/client/internal/dns/host_windows.go
@@ -184,6 +184,10 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return nil
}
+func (r *registryConfigurator) string() string {
+ return "registry"
+}
+
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err)
diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go
index 63bbead77..10b4e6a6e 100644
--- a/client/internal/dns/network_manager_unix.go
+++ b/client/internal/dns/network_manager_unix.go
@@ -179,6 +179,10 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
return nil
}
+func (n *networkManagerDbusConfigurator) string() string {
+ return "network-manager"
+}
+
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil {
diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go
index 6b5fdaf86..54c4c75bf 100644
--- a/client/internal/dns/resolvconf_unix.go
+++ b/client/internal/dns/resolvconf_unix.go
@@ -91,7 +91,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
if err != nil {
log.Errorf("restore host dns: %s", err)
}
- return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
+ return ErrRouteAllWithoutNameserverGroup
}
searchDomainList := searchDomains(config)
@@ -139,6 +139,10 @@ func (r *resolvconf) restoreHostDNS() error {
return nil
}
+func (r *resolvconf) string() string {
+ return fmt.Sprintf("resolvconf (%s)", r.implType)
+}
+
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
var cmd *exec.Cmd
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index d4d68370d..bc87012f2 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -2,6 +2,7 @@ package dns
import (
"context"
+ "errors"
"fmt"
"net/netip"
"runtime"
@@ -15,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -395,12 +397,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
- return fmt.Errorf("not applying dns update, error: %v", err)
+ return fmt.Errorf("local handler updater: %w", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
- return fmt.Errorf("not applying dns update, error: %v", err)
+ return fmt.Errorf("upstream handler updater: %w", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
@@ -420,6 +422,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
log.Error(err)
+ s.handleErrNoGroupaAll(err)
}
go func() {
@@ -438,16 +441,33 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil
}
+func (s *DefaultServer) handleErrNoGroupaAll(err error) {
+ if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
+ return
+ }
+
+ if s.statusRecorder == nil {
+ return
+ }
+
+ s.statusRecorder.PublishEvent(
+ cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS,
+ "The host dns manager does not support match domains",
+ "The host dns manager does not support match domains without a catch-all nameserver group.",
+ map[string]string{"manager": s.hostManager.string()},
+ )
+}
+
func (s *DefaultServer) buildLocalHandlerUpdate(
customZones []nbdns.CustomZone,
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
-
var muxUpdates []handlerWrapper
localRecords := make(map[string][]nbdns.SimpleRecord)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
- return nil, nil, fmt.Errorf("received an empty list of records")
+ log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
+ continue
}
muxUpdates = append(muxUpdates, handlerWrapper{
@@ -460,7 +480,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
- return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
+ log.Warnf("received an invalid class type: %s", record.Class)
+ continue
}
key := buildRecordKey(record.Name, class, uint16(record.Type))
@@ -670,6 +691,7 @@ func (s *DefaultServer) upstreamCallbacks(
}
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
+ s.handleErrNoGroupaAll(err)
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
@@ -708,6 +730,7 @@ func (s *DefaultServer) upstreamCallbacks(
if s.hostManager != nil {
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
+ s.handleErrNoGroupaAll(err)
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index e9ddd5f59..94b87124b 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
- name: "Invalid Custom Zone Records list Should Fail",
+ name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
@@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
- shouldFail: true,
+ expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
+ domain: ".",
+ handler: dummyHandler,
+ priority: PriorityDefault,
+ }},
},
{
name: "Empty Config Should Succeed and Clean Maps",
@@ -352,7 +356,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
- dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
if err != nil {
t.Fatal(err)
}
@@ -409,7 +413,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
- newNet, err := stdnet.NewNet(nil)
+ newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
@@ -461,7 +465,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
- dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -562,7 +566,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
- dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
+ dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false)
if err != nil {
t.Fatalf("%v", err)
}
@@ -635,7 +639,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
{false, "domain2", false},
},
},
- statusRecorder: &peer.Status{},
+ statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
@@ -696,7 +700,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
var dnsList []string
dnsConfig := nbdns.Config{}
- dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
+ dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -720,7 +724,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
- dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
+ dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -812,7 +816,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
- dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
+ dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -883,7 +887,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
- newNet, err := stdnet.NewNet(nil)
+ newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err
diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go
index a031be582..a87cc73e5 100644
--- a/client/internal/dns/systemd_linux.go
+++ b/client/internal/dns/systemd_linux.go
@@ -154,6 +154,10 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
return nil
}
+func (s *systemdDbusConfigurator) string() string {
+ return "dbus"
+}
+
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
if err != nil {
diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go
index d269107e3..a22689cf9 100644
--- a/client/internal/dns/upstream.go
+++ b/client/internal/dns/upstream.go
@@ -183,6 +183,19 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
}
u.disable(err)
+
+ if u.statusRecorder == nil {
+ return
+ }
+
+ u.statusRecorder.PublishEvent(
+ proto.SystemEvent_WARNING,
+ proto.SystemEvent_DNS,
+ "All upstream servers failed (fail count exceeded)",
+ "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
+ map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
+ // TODO add domain meta
+ )
}
// probeAvailability tests all upstream servers simultaneously and
@@ -232,10 +245,14 @@ func (u *upstreamResolverBase) probeAvailability() {
if !success {
u.disable(errors.ErrorOrNil())
+ if u.statusRecorder == nil {
+ return
+ }
+
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
- "All upstream servers failed",
+ "All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
)
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 865df5fbb..3d7802675 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -44,6 +44,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -155,7 +156,7 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
- wgInterface iface.IWGIface
+ wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault
@@ -535,15 +536,18 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
- if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
- if allowedIPs != strings.Join(p.AllowedIps, ",") {
- modified = append(modified, p)
- continue
- }
- err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
- if err != nil {
- log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
- }
+ allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
+ if !ok {
+ continue
+ }
+ if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) {
+ modified = append(modified, p)
+ continue
+ }
+
+ err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
+ if err != nil {
+ log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
}
}
@@ -682,6 +686,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
+ e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
+
return nil
}
@@ -967,7 +973,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{}
}
- if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
+ if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
@@ -1036,7 +1042,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
return dnsRoutes
}
-func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
+func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
@@ -1076,6 +1082,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}
+
+ if len(dnsUpdate.CustomZones) > 0 {
+ addReverseZone(&dnsUpdate, network)
+ }
+
return dnsUpdate
}
@@ -1109,34 +1120,45 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// addNewPeer add peer if connection doesn't exist
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey()
- peerIPs := peerConfig.GetAllowedIps()
- if _, ok := e.peerStore.PeerConn(peerKey); !ok {
- conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
- if err != nil {
- return fmt.Errorf("create peer connection: %w", err)
- }
-
- if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
- conn.Close()
- return fmt.Errorf("peer already exists: %s", peerKey)
- }
-
- if e.beforePeerHook != nil && e.afterPeerHook != nil {
- conn.AddBeforeAddPeerHook(e.beforePeerHook)
- conn.AddAfterRemovePeerHook(e.afterPeerHook)
- }
-
- err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
- if err != nil {
- log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
- }
-
- conn.Open()
+ peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps()))
+ if _, ok := e.peerStore.PeerConn(peerKey); ok {
+ return nil
}
+
+ for _, ipString := range peerConfig.GetAllowedIps() {
+ allowedNetIP, err := netip.ParsePrefix(ipString)
+ if err != nil {
+ log.Errorf("failed to parse allowedIPS: %v", err)
+ return err
+ }
+ peerIPs = append(peerIPs, allowedNetIP)
+ }
+
+ conn, err := e.createPeerConn(peerKey, peerIPs)
+ if err != nil {
+ return fmt.Errorf("create peer connection: %w", err)
+ }
+
+ if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
+ conn.Close()
+ return fmt.Errorf("peer already exists: %s", peerKey)
+ }
+
+ if e.beforePeerHook != nil && e.afterPeerHook != nil {
+ conn.AddBeforeAddPeerHook(e.beforePeerHook)
+ conn.AddAfterRemovePeerHook(e.afterPeerHook)
+ }
+
+ err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
+ if err != nil {
+ log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
+ }
+
+ conn.Open()
return nil
}
-func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
+func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{
@@ -1382,7 +1404,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
- dnsCfg := toDNSConfig(netMap.GetDNSConfig())
+ dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil
}
@@ -1889,3 +1911,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
return prefixes, nberrors.FormatErrorOrNil(merr)
}
+
+// compareNetIPLists compares a list of netip.Prefix with a list of strings.
+// return true if both lists are equal, false otherwise.
+func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool {
+ if len(list1) != len(list2) {
+ return false
+ }
+
+ freq := make(map[string]int, len(list1))
+ for _, p := range list1 {
+ freq[p.String()]++
+ }
+
+ for _, s := range list2 {
+ p, err := netip.ParsePrefix(s)
+ if err != nil {
+ return false // invalid prefix in list2.
+ }
+ key := p.String()
+ if freq[key] == 0 {
+ return false
+ }
+ freq[key]--
+ }
+
+ // all counts should be zero if lists are equal.
+ for _, count := range freq {
+ if count != 0 {
+ return false
+ }
+ }
+ return true
+}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 1ac9a0430..9de1da28d 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -22,11 +22,16 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
+ wgdevice "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -65,6 +70,114 @@ var (
}
)
+type MockWGIface struct {
+ CreateFunc func() error
+ CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
+ IsUserspaceBindFunc func() bool
+ NameFunc func() string
+ AddressFunc func() device.WGAddress
+ ToInterfaceFunc func() *net.Interface
+ UpFunc func() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddrFunc func(newAddr string) error
+ UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeerFunc func(peerKey string) error
+ AddAllowedIPFunc func(peerKey string, allowedIP string) error
+ RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
+ CloseFunc func() error
+ SetFilterFunc func(filter device.PacketFilter) error
+ GetFilterFunc func() device.PacketFilter
+ GetDeviceFunc func() *device.FilteredDevice
+ GetWGDeviceFunc func() *wgdevice.Device
+ GetStatsFunc func(peerKey string) (configurer.WGStats, error)
+ GetInterfaceGUIDStringFunc func() (string, error)
+ GetProxyFunc func() wgproxy.Proxy
+ GetNetFunc func() *netstack.Net
+}
+
+func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
+ return m.GetInterfaceGUIDStringFunc()
+}
+
+func (m *MockWGIface) Create() error {
+ return m.CreateFunc()
+}
+
+func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
+ return m.CreateOnAndroidFunc(routeRange, ip, domains)
+}
+
+func (m *MockWGIface) IsUserspaceBind() bool {
+ return m.IsUserspaceBindFunc()
+}
+
+func (m *MockWGIface) Name() string {
+ return m.NameFunc()
+}
+
+func (m *MockWGIface) Address() device.WGAddress {
+ return m.AddressFunc()
+}
+
+func (m *MockWGIface) ToInterface() *net.Interface {
+ return m.ToInterfaceFunc()
+}
+
+func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
+ return m.UpFunc()
+}
+
+func (m *MockWGIface) UpdateAddr(newAddr string) error {
+ return m.UpdateAddrFunc(newAddr)
+}
+
+func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+ return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+}
+
+func (m *MockWGIface) RemovePeer(peerKey string) error {
+ return m.RemovePeerFunc(peerKey)
+}
+
+func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
+ return m.AddAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
+ return m.RemoveAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) Close() error {
+ return m.CloseFunc()
+}
+
+func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
+ return m.SetFilterFunc(filter)
+}
+
+func (m *MockWGIface) GetFilter() device.PacketFilter {
+ return m.GetFilterFunc()
+}
+
+func (m *MockWGIface) GetDevice() *device.FilteredDevice {
+ return m.GetDeviceFunc()
+}
+
+func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
+ return m.GetWGDeviceFunc()
+}
+
+func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
+ return m.GetStatsFunc(peerKey)
+}
+
+func (m *MockWGIface) GetProxy() wgproxy.Proxy {
+ return m.GetProxyFunc()
+}
+
+func (m *MockWGIface) GetNet() *netstack.Net {
+ return m.GetNetFunc()
+}
+
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
@@ -246,11 +359,20 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
peer.NewRecorder("https://mgm"),
nil)
- wgIface := &iface.MockWGIface{
+ wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error {
return nil
},
+ AddressFunc: func() iface.WGAddress {
+ return iface.WGAddress{
+ IP: net.ParseIP("10.20.0.1"),
+ Network: &net.IPNet{
+ IP: net.ParseIP("10.20.0.0"),
+ Mask: net.IPv4Mask(255, 255, 255, 0),
+ },
+ }
+ },
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@@ -414,7 +536,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
}
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
- if conn.WgConfig().AllowedIps != expectedAllowedIPs {
+ if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.WgConfig().AllowedIps)
}
@@ -693,6 +815,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
+ {
+ Domain: "0.66.100.in-addr.arpa.",
+ },
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
@@ -722,6 +847,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
+ {
+ Domain: "0.66.100.in-addr.arpa.",
+ },
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
@@ -1111,6 +1239,91 @@ func Test_CheckFilesEqual(t *testing.T) {
}
}
+func TestCompareNetIPLists(t *testing.T) {
+ tests := []struct {
+ name string
+ list1 []netip.Prefix
+ list2 []string
+ expected bool
+ }{
+ {
+ name: "both empty",
+ list1: []netip.Prefix{},
+ list2: []string{},
+ expected: true,
+ },
+ {
+ name: "single match ipv4",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
+ list2: []string{"192.168.0.0/24"},
+ expected: true,
+ },
+ {
+ name: "multiple match ipv4, different order",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")},
+ list2: []string{"10.0.0.0/8", "192.168.1.0/24"},
+ expected: true,
+ },
+ {
+ name: "ipv4 mismatch due to extra element in list2",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ list2: []string{"192.168.1.0/24", "10.0.0.0/8"},
+ expected: false,
+ },
+ {
+ name: "ipv4 mismatch due to duplicate count",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")},
+ list2: []string{"192.168.1.0/24"},
+ expected: false,
+ },
+ {
+ name: "invalid prefix in list2",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ list2: []string{"invalid-prefix"},
+ expected: false,
+ },
+ {
+ name: "ipv4 mismatch because different prefixes",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ list2: []string{"10.0.0.0/8"},
+ expected: false,
+ },
+ {
+ name: "single match ipv6",
+ list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
+ list2: []string{"2001:db8::/32"},
+ expected: true,
+ },
+ {
+ name: "multiple match ipv6, different order",
+ list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")},
+ list2: []string{"fe80::/10", "2001:db8::/32"},
+ expected: true,
+ },
+ {
+ name: "mixed ipv4 and ipv6 match",
+ list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")},
+ list2: []string{"2001:db8::/32", "192.168.1.0/24"},
+ expected: true,
+ },
+ {
+ name: "ipv6 mismatch with invalid prefix",
+ list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")},
+ list2: []string{"invalid-ipv6"},
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := compareNetIPLists(tt.list1, tt.list2)
+ if result != tt.expected {
+ t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected)
+ }
+ })
+ }
+}
+
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
@@ -1131,7 +1344,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
info := system.GetInfo(ctx)
- resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
+ resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1227,7 +1440,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/iface.go b/client/internal/iface.go
new file mode 100644
index 000000000..bd0069c19
--- /dev/null
+++ b/client/internal/iface.go
@@ -0,0 +1,8 @@
+//go:build !windows
+// +build !windows
+
+package internal
+
+type WGIface interface {
+ wgIfaceBase
+}
diff --git a/client/iface/iwginterface.go b/client/internal/iface_common.go
similarity index 84%
rename from client/iface/iwginterface.go
rename to client/internal/iface_common.go
index 2b919ac9e..65b425015 100644
--- a/client/iface/iwginterface.go
+++ b/client/internal/iface_common.go
@@ -1,9 +1,8 @@
-//go:build !windows
-
-package iface
+package internal
import (
"net"
+ "net/netip"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -16,7 +15,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
-type IWGIface interface {
+type wgIfaceBase interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
@@ -26,7 +25,7 @@ type IWGIface interface {
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
- UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
diff --git a/client/internal/iface_windows.go b/client/internal/iface_windows.go
new file mode 100644
index 000000000..113217815
--- /dev/null
+++ b/client/internal/iface_windows.go
@@ -0,0 +1,6 @@
+package internal
+
+type WGIface interface {
+ wgIfaceBase
+ GetInterfaceGUIDString() (string, error)
+}
diff --git a/client/internal/login.go b/client/internal/login.go
index 092f2309c..395a17199 100644
--- a/client/internal/login.go
+++ b/client/internal/login.go
@@ -140,7 +140,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.DisableDNS,
config.DisableFirewall,
)
- loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
+ loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
return nil, err
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 8bbea6a2b..9b4d1a554 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -5,9 +5,9 @@ import (
"fmt"
"math/rand"
"net"
+ "net/netip"
"os"
"runtime"
- "strings"
"sync"
"time"
@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -56,8 +55,8 @@ const (
type WgConfig struct {
WgListenPort int
RemoteKey string
- WgInterface iface.IWGIface
- AllowedIps string
+ WgInterface WGIface
+ AllowedIps []netip.Prefix
PreSharedKey *wgtypes.Key
}
@@ -92,11 +91,10 @@ type Conn struct {
statusRecorder *Status
signaler *Signaler
relayManager *relayClient.Manager
- allowedIP net.IP
handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
- onDisconnected func(remotePeer string, wgIP string)
+ onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
@@ -121,10 +119,8 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
- allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
- if err != nil {
- log.Errorf("failed to parse allowedIPS: %v", err)
- return nil, err
+ if len(config.WgConfig.AllowedIps) == 0 {
+ return nil, fmt.Errorf("allowed IPs is empty")
}
ctx, ctxCancel := context.WithCancel(engineCtx)
@@ -138,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRecorder: statusRecorder,
signaler: signaler,
relayManager: relayManager,
- allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
@@ -148,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
- conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
+ workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
+ conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
@@ -180,7 +176,7 @@ func (conn *Conn) Open() {
peerState := State{
PubKey: conn.config.Key,
- IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
+ IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected,
Mux: new(sync.RWMutex),
@@ -246,7 +242,7 @@ func (conn *Conn) Close() {
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
- conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
+ conn.onDisconnected(conn.config.WgConfig.RemoteKey)
}
conn.setStatusToDisconnected()
@@ -277,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR
}
// SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected
-func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) {
+func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
@@ -602,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
if conn.onConnected != nil {
- conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
+ conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr)
}
}
@@ -699,7 +695,7 @@ func (conn *Conn) freeUpConnID() {
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
- IP: conn.allowedIP,
+ IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort,
}
@@ -753,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
}
// AllowedIP returns the allowed IP of the remote peer
-func (conn *Conn) AllowedIP() net.IP {
- return conn.allowedIP
+func (conn *Conn) AllowedIP() netip.Addr {
+ return conn.config.WgConfig.AllowedIps[0].Addr()
}
func isController(config ConnConfig) bool {
diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go
new file mode 100644
index 000000000..c7b6de9ea
--- /dev/null
+++ b/client/internal/peer/iface.go
@@ -0,0 +1,19 @@
+package peer
+
+import (
+ "net"
+ "net/netip"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/wgproxy"
+)
+
+type WGIface interface {
+ UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ GetStats(peerKey string) (configurer.WGStats, error)
+ GetProxy() wgproxy.Proxy
+}
diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go
index 6b3385ff5..15d34d3d0 100644
--- a/client/internal/peerstore/store.go
+++ b/client/internal/peerstore/store.go
@@ -1,7 +1,7 @@
package peerstore
import (
- "net"
+ "net/netip"
"sync"
"golang.org/x/exp/maps"
@@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
return p, true
}
-func (s *Store) AllowedIPs(pubKey string) (string, bool) {
- s.peerConnsMu.RLock()
- defer s.peerConnsMu.RUnlock()
-
- p, ok := s.peerConns[pubKey]
- if !ok {
- return "", false
- }
- return p.WgConfig().AllowedIps, true
-}
-
-func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
+func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
@@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
if !ok {
return nil, false
}
+ return p.WgConfig().AllowedIps, true
+}
+
+func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) {
+ s.peerConnsMu.RLock()
+ defer s.peerConnsMu.RUnlock()
+
+ p, ok := s.peerConns[pubKey]
+ if !ok {
+ return netip.Addr{}, false
+ }
return p.AllowedIP(), true
}
diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go
index bf019453b..d2d7408fd 100644
--- a/client/internal/rosenpass/manager.go
+++ b/client/internal/rosenpass/manager.go
@@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil
}
-func (m *Manager) OnDisconnected(peerKey string, wgIP string) {
+func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock()
defer m.lock.Unlock()
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index 5984e69cb..6680f727a 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -11,12 +11,12 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
@@ -62,7 +62,7 @@ type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface iface.IWGIface
+ wgInterface iface.WGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
@@ -75,7 +75,7 @@ type clientNetwork struct {
func newClientNetworkWatcher(
ctx context.Context,
dnsRouteInterval time.Duration,
- wgInterface iface.IWGIface,
+ wgInterface iface.WGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
@@ -306,11 +306,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
return nil
}
+ var isNew bool
if c.currentChosen == nil {
// If they were not previously assigned to another peer, add routes to the system first
if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("add route: %w", err)
}
+ isNew = true
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireGuardPeer(); err != nil {
@@ -324,6 +326,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
+ if isNew {
+ c.connectEvent()
+ }
+
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
@@ -331,6 +337,35 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
return nil
}
+func (c *clientNetwork) connectEvent() {
+ var defaultRoute bool
+ for _, r := range c.routes {
+ if r.Network.Bits() == 0 {
+ defaultRoute = true
+ break
+ }
+ }
+
+ if !defaultRoute {
+ return
+ }
+
+ meta := map[string]string{
+ "network": c.handler.String(),
+ }
+ if c.currentChosen != nil {
+ meta["id"] = string(c.currentChosen.NetID)
+ meta["peer"] = c.currentChosen.Peer
+ }
+ c.statusRecorder.PublishEvent(
+ proto.SystemEvent_INFO,
+ proto.SystemEvent_NETWORK,
+ "Default route added",
+ "Exit node connected.",
+ meta,
+ )
+}
+
func (c *clientNetwork) disconnectEvent(rsn reason) {
var defaultRoute bool
for _, r := range c.routes {
@@ -349,29 +384,27 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
var userMessage string
meta := make(map[string]string)
+ if c.currentChosen != nil {
+ meta["id"] = string(c.currentChosen.NetID)
+ meta["peer"] = c.currentChosen.Peer
+ }
+ meta["network"] = c.handler.String()
switch rsn {
case reasonShutdown:
severity = proto.SystemEvent_INFO
message = "Default route removed"
userMessage = "Exit node disconnected."
- meta["network"] = c.handler.String()
case reasonRouteUpdate:
severity = proto.SystemEvent_INFO
message = "Default route updated due to configuration change"
- meta["network"] = c.handler.String()
case reasonPeerUpdate:
severity = proto.SystemEvent_WARNING
message = "Default route disconnected due to peer unreachability"
userMessage = "Exit node connection lost. Your internet access might be affected."
- if c.currentChosen != nil {
- meta["peer"] = c.currentChosen.Peer
- meta["network"] = c.handler.String()
- }
default:
severity = proto.SystemEvent_ERROR
- message = "Default route disconnected for unknown reason"
+ message = "Default route disconnected for unknown reasons"
userMessage = "Exit node disconnected for unknown reasons."
- meta["network"] = c.handler.String()
}
c.statusRecorder.PublishEvent(
@@ -468,7 +501,7 @@ func handlerFromRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
- wgInterface iface.IWGIface,
+ wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go
index 10cb03f1d..f36285cc4 100644
--- a/client/internal/routemanager/dnsinterceptor/handler.go
+++ b/client/internal/routemanager/dnsinterceptor/handler.go
@@ -3,7 +3,6 @@ package dnsinterceptor
import (
"context"
"fmt"
- "net"
"net/netip"
"strings"
"sync"
@@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Timeout: 5 * time.Second,
Net: "udp",
}
- upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
+ upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
- log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
+ log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
@@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg,
}
}
-func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
+func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {
- return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
+ return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
return peerAllowedIP, nil
}
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index a0fff7713..5ef18a47e 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -13,8 +13,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
@@ -48,7 +48,7 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface iface.IWGIface
+ wgInterface iface.WGIface
resolverAddr string
}
@@ -58,7 +58,7 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
- wgInterface iface.IWGIface,
+ wgInterface iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
diff --git a/client/internal/routemanager/iface/iface.go b/client/internal/routemanager/iface/iface.go
new file mode 100644
index 000000000..57dbec03d
--- /dev/null
+++ b/client/internal/routemanager/iface/iface.go
@@ -0,0 +1,9 @@
+//go:build !windows
+// +build !windows
+
+package iface
+
+// WGIface defines subset methods of interface required for router
+type WGIface interface {
+ wgIfaceBase
+}
diff --git a/client/internal/routemanager/iface/iface_common.go b/client/internal/routemanager/iface/iface_common.go
new file mode 100644
index 000000000..8b2dc9714
--- /dev/null
+++ b/client/internal/routemanager/iface/iface_common.go
@@ -0,0 +1,22 @@
+package iface
+
+import (
+ "net"
+
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type wgIfaceBase interface {
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+
+ Name() string
+ Address() iface.WGAddress
+ ToInterface() *net.Interface
+ IsUserspaceBind() bool
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
+}
diff --git a/client/internal/routemanager/iface/iface_windows.go b/client/internal/routemanager/iface/iface_windows.go
new file mode 100644
index 000000000..7ab7e239c
--- /dev/null
+++ b/client/internal/routemanager/iface/iface_windows.go
@@ -0,0 +1,7 @@
+package iface
+
+// WGIface defines subset methods of interface required for router
+type WGIface interface {
+ wgIfaceBase
+ GetInterfaceGUIDString() (string, error)
+}
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index 52de0948b..ae0d1d220 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -15,13 +15,13 @@ import (
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@@ -52,7 +52,7 @@ type ManagerConfig struct {
Context context.Context
PublicKey string
DNSRouteInterval time.Duration
- WGInterface iface.IWGIface
+ WGInterface iface.WGIface
StatusRecorder *peer.Status
RelayManager *relayClient.Manager
InitialRoutes []*route.Route
@@ -74,7 +74,7 @@ type DefaultManager struct {
sysOps *systemops.SysOps
statusRecorder *peer.Status
relayMgr *relayClient.Manager
- wgInterface iface.IWGIface
+ wgInterface iface.WGIface
pubKey string
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go
index e9cfa0826..48bb0380d 100644
--- a/client/internal/routemanager/server_android.go
+++ b/client/internal/routemanager/server_android.go
@@ -7,8 +7,8 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/route"
)
@@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
return nil
}
-func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
+func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}
diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go
index 7ddcccb0b..603818bba 100644
--- a/client/internal/routemanager/server_nonandroid.go
+++ b/client/internal/routemanager/server_nonandroid.go
@@ -11,8 +11,12 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
+<<<<<<< HEAD
+=======
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
+ "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
+>>>>>>> main
"github.com/netbirdio/netbird/route"
)
@@ -21,11 +25,11 @@ type serverRouter struct {
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
- wgInterface iface.IWGIface
+ wgInterface iface.WGIface
statusRecorder *peer.Status
}
-func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
+func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
return &serverRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go
index bb620ee68..ea63f02fc 100644
--- a/client/internal/routemanager/sysctl/sysctl_linux.go
+++ b/client/internal/routemanager/sysctl/sysctl_linux.go
@@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
)
const (
@@ -23,7 +23,7 @@ const (
)
// Setup configures sysctl settings for RP filtering and source validation.
-func Setup(wgIface iface.IWGIface) (map[string]int, error) {
+func Setup(wgIface iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go
index d1cb83bfb..5c117b94d 100644
--- a/client/internal/routemanager/systemops/systemops.go
+++ b/client/internal/routemanager/systemops/systemops.go
@@ -5,7 +5,7 @@ import (
"net/netip"
"sync"
- "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
@@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
- wgInterface iface.IWGIface
+ wgInterface iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@@ -30,7 +30,7 @@ type SysOps struct {
notifier *notifier.Notifier
}
-func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
+func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 31b7f3ac2..eaef01815 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -16,8 +16,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
+ "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -149,7 +149,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
-func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
+func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
diff --git a/client/internal/stdnet/filter.go b/client/internal/stdnet/filter.go
index c04250b2d..e45714001 100644
--- a/client/internal/stdnet/filter.go
+++ b/client/internal/stdnet/filter.go
@@ -21,7 +21,6 @@ func InterfaceFilter(disallowList []string) func(string) bool {
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" {
- log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false
}
}
diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go
index 2e87475a5..aa9fdd045 100644
--- a/client/internal/stdnet/stdnet.go
+++ b/client/internal/stdnet/stdnet.go
@@ -5,11 +5,16 @@ package stdnet
import (
"fmt"
+ "slices"
+ "sync"
+ "time"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
)
+const updateInterval = 30 * time.Second
+
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
type Net struct {
@@ -18,6 +23,10 @@ type Net struct {
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
+ lastUpdate time.Time
+
+ // mu is shared between interfaces and lastUpdate
+ mu sync.Mutex
}
// NewNetWithDiscover creates a new StdNet instance.
@@ -43,18 +52,40 @@ func NewNet(disallowList []string) (*Net, error) {
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
// wasn't specified.
func (n *Net) UpdateInterfaces() (err error) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ return n.updateInterfaces()
+}
+
+func (n *Net) updateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces()
if err != nil {
return err
}
+
n.interfaces = n.filterInterfaces(allIfaces)
+
+ n.lastUpdate = time.Now()
+
return nil
}
// Interfaces returns a slice of interfaces which are available on the
// system
func (n *Net) Interfaces() ([]*transport.Interface, error) {
- return n.interfaces, nil
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if time.Since(n.lastUpdate) < updateInterval {
+ return slices.Clone(n.interfaces), nil
+ }
+
+ if err := n.updateInterfaces(); err != nil {
+ return nil, fmt.Errorf("update interfaces: %w", err)
+ }
+
+ return slices.Clone(n.interfaces), nil
}
// InterfaceByIndex returns the interface specified by index.
@@ -63,6 +94,8 @@ func (n *Net) Interfaces() ([]*transport.Interface, error) {
// sharing the logical data link; for more precision use
// InterfaceByName.
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
for _, ifc := range n.interfaces {
if ifc.Index == index {
return ifc, nil
@@ -74,6 +107,8 @@ func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
// InterfaceByName returns the interface specified by name.
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
for _, ifc := range n.interfaces {
if ifc.Name == name {
return ifc, nil
@@ -87,7 +122,7 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
if n.interfaceFilter == nil {
return interfaces
}
- result := []*transport.Interface{}
+ var result []*transport.Interface
for _, iface := range interfaces {
if n.interfaceFilter(iface.Name) {
result = append(result, iface)
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index 0d7d7146a..d04d7a9c0 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -146,6 +146,7 @@ const (
SystemEvent_DNS SystemEvent_Category = 1
SystemEvent_AUTHENTICATION SystemEvent_Category = 2
SystemEvent_CONNECTIVITY SystemEvent_Category = 3
+ SystemEvent_SYSTEM SystemEvent_Category = 4
)
// Enum value maps for SystemEvent_Category.
@@ -155,12 +156,14 @@ var (
1: "DNS",
2: "AUTHENTICATION",
3: "CONNECTIVITY",
+ 4: "SYSTEM",
}
SystemEvent_Category_value = map[string]int32{
"NETWORK": 0,
"DNS": 1,
"AUTHENTICATION": 2,
"CONNECTIVITY": 3,
+ "SYSTEM": 4,
}
)
@@ -4020,7 +4023,7 @@ var file_daemon_proto_rawDesc = []byte{
0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e,
0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73,
0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73,
- 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x87, 0x04, 0x0a,
+ 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x93, 0x04, 0x0a,
0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02,
0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38, 0x0a, 0x08,
0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c,
@@ -4048,116 +4051,117 @@ var file_daemon_proto_rawDesc = []byte{
0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x08,
0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, 0x41, 0x52, 0x4e,
0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x02,
- 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x46,
+ 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x52,
0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, 0x07, 0x4e, 0x45,
0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, 0x53, 0x10, 0x01,
0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, 0x41, 0x54, 0x49,
0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x49,
- 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65,
- 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65,
- 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
- 0x2b, 0x0a, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45,
- 0x76, 0x65, 0x6e, 0x74, 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08,
- 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e,
- 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01,
- 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45,
- 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04,
- 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45,
- 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07,
- 0x32, 0xb3, 0x0b, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69,
- 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61,
- 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
- 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
- 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61,
- 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65,
- 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73,
- 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e,
- 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
- 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
- 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
- 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e,
- 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
- 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69,
- 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65,
- 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73,
- 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73,
- 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63,
- 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d,
+ 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, 0x54, 0x45, 0x4d,
+ 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52,
+ 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65,
+ 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, 0x0a, 0x06, 0x65,
+ 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74,
+ 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c,
+ 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
+ 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05,
+ 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52,
+ 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04,
+ 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10,
+ 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb3, 0x0b, 0x0a,
+ 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36,
+ 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
+ 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e,
+ 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
+ 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53,
+ 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75,
+ 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69,
+ 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
+ 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d,
+ 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
+ 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
+ 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
+ 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
+ 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65,
+ 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
+ 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
+ 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
+ 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70,
+ 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65,
+ 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75,
+ 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73,
+ 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
+ 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74,
+ 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
+ 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71,
+ 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
+ 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70,
+ 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65,
+ 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65,
+ 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
+ 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
- 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73,
- 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65,
- 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65,
- 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e,
- 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74,
- 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x4a, 0x0a, 0x0f, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c,
- 0x65, 0x73, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74,
- 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65,
- 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44,
- 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
- 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
- 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
- 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65,
- 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67,
- 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
- 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
+ 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0f, 0x46,
+ 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x14,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71,
+ 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f,
+ 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73,
+ 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67,
+ 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75,
+ 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
+ 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
+ 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67,
+ 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
+ 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53,
+ 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73,
- 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
- 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74,
- 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
- 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61,
- 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73,
- 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74,
- 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65,
- 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65,
- 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
- 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
- 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e,
- 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
- 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72,
- 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65,
- 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65,
- 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e,
+ 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61,
+ 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73,
+ 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74,
+ 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a,
+ 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65,
+ 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43,
+ 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
+ 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61,
+ 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65,
+ 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74,
+ 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a,
+ 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65,
+ 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d,
+ 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
+ 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e,
+ 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74,
+ 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48,
+ 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b,
- 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f,
- 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
- 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69,
- 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00,
- 0x30, 0x01, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
- 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e,
- 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
- 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
+ 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65,
+ 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73,
+ 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
+ 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42,
+ 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47,
+ 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
+ 0x6f, 0x74, 0x6f, 0x33,
}
var (
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 4f2032d00..49e577853 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -457,6 +457,7 @@ message SystemEvent {
DNS = 1;
AUTHENTICATION = 2;
CONNECTIVITY = 3;
+ SYSTEM = 4;
}
string id = 1;
diff --git a/client/server/network.go b/client/server/network.go
index aaf361524..d310f4da1 100644
--- a/client/server/network.go
+++ b/client/server/network.go
@@ -6,6 +6,7 @@ import (
"net/netip"
"slices"
"sort"
+ "strings"
"golang.org/x/exp/maps"
@@ -134,6 +135,18 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
}
routeManager.TriggerSelection(routeManager.GetClientRoutes())
+ s.statusRecorder.PublishEvent(
+ proto.SystemEvent_INFO,
+ proto.SystemEvent_SYSTEM,
+ "Network selection changed",
+ "",
+ map[string]string{
+ "networks": strings.Join(req.GetNetworkIDs(), ", "),
+ "append": fmt.Sprint(req.GetAppend()),
+ "all": fmt.Sprint(req.GetAll()),
+ },
+ )
+
return &proto.SelectNetworksResponse{}, nil
}
@@ -164,6 +177,18 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
}
routeManager.TriggerSelection(routeManager.GetClientRoutes())
+ s.statusRecorder.PublishEvent(
+ proto.SystemEvent_INFO,
+ proto.SystemEvent_SYSTEM,
+ "Network deselection changed",
+ "",
+ map[string]string{
+ "networks": strings.Join(req.GetNetworkIDs(), ", "),
+ "append": fmt.Sprint(req.GetAppend()),
+ "all": fmt.Sprint(req.GetAll()),
+ },
+ )
+
return &proto.SelectNetworksResponse{}, nil
}
diff --git a/client/server/server.go b/client/server/server.go
index dcc6e5651..8907f541f 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -751,6 +751,11 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
}
+ disableNotifications := true
+ if s.config.DisableNotifications != nil {
+ disableNotifications = *s.config.DisableNotifications
+ }
+
return &proto.GetConfigResponse{
ManagementUrl: managementURL,
ConfigFile: s.latestConfigInput.ConfigPath,
@@ -763,13 +768,14 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
ServerSSHAllowed: *s.config.ServerSSHAllowed,
RosenpassEnabled: s.config.RosenpassEnabled,
RosenpassPermissive: s.config.RosenpassPermissive,
- DisableNotifications: s.config.DisableNotifications,
+ DisableNotifications: disableNotifications,
}, nil
}
+
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
- if !isUIActive {
+ if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications {
if err := sendTerminalNotification(); err != nil {
log.Errorf("send session expire terminal notification: %v", err)
}
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 278ec246c..0c0f32fec 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -135,7 +135,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/cmd/status_event.go b/client/status/event.go
similarity index 86%
rename from client/cmd/status_event.go
rename to client/status/event.go
index 9331570e6..2b65c9fa3 100644
--- a/client/cmd/status_event.go
+++ b/client/status/event.go
@@ -1,4 +1,4 @@
-package cmd
+package status
import (
"fmt"
@@ -9,7 +9,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
-type systemEventOutput struct {
+type SystemEventOutput struct {
ID string `json:"id" yaml:"id"`
Severity string `json:"severity" yaml:"severity"`
Category string `json:"category" yaml:"category"`
@@ -19,10 +19,10 @@ type systemEventOutput struct {
Metadata map[string]string `json:"metadata" yaml:"metadata"`
}
-func mapEvents(protoEvents []*proto.SystemEvent) []systemEventOutput {
- events := make([]systemEventOutput, len(protoEvents))
+func mapEvents(protoEvents []*proto.SystemEvent) []SystemEventOutput {
+ events := make([]SystemEventOutput, len(protoEvents))
for i, event := range protoEvents {
- events[i] = systemEventOutput{
+ events[i] = SystemEventOutput{
ID: event.GetId(),
Severity: event.GetSeverity().String(),
Category: event.GetCategory().String(),
@@ -35,7 +35,7 @@ func mapEvents(protoEvents []*proto.SystemEvent) []systemEventOutput {
return events
}
-func parseEvents(events []systemEventOutput) string {
+func parseEvents(events []SystemEventOutput) string {
if len(events) == 0 {
return " No events recorded"
}
diff --git a/client/status/status.go b/client/status/status.go
new file mode 100644
index 000000000..43acc9197
--- /dev/null
+++ b/client/status/status.go
@@ -0,0 +1,729 @@
+package status
+
+import (
+ "encoding/json"
+ "fmt"
+ "net"
+ "net/netip"
+ "os"
+ "runtime"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/yaml.v3"
+
+ "github.com/netbirdio/netbird/client/anonymize"
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/proto"
+ "github.com/netbirdio/netbird/version"
+)
+
+type PeerStateDetailOutput struct {
+ FQDN string `json:"fqdn" yaml:"fqdn"`
+ IP string `json:"netbirdIp" yaml:"netbirdIp"`
+ PubKey string `json:"publicKey" yaml:"publicKey"`
+ Status string `json:"status" yaml:"status"`
+ LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
+ ConnType string `json:"connectionType" yaml:"connectionType"`
+ IceCandidateType IceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
+ IceCandidateEndpoint IceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
+ RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
+ LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
+ TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
+ TransferSent int64 `json:"transferSent" yaml:"transferSent"`
+ Latency time.Duration `json:"latency" yaml:"latency"`
+ RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
+ Networks []string `json:"networks" yaml:"networks"`
+}
+
+type PeersStateOutput struct {
+ Total int `json:"total" yaml:"total"`
+ Connected int `json:"connected" yaml:"connected"`
+ Details []PeerStateDetailOutput `json:"details" yaml:"details"`
+}
+
+type SignalStateOutput struct {
+ URL string `json:"url" yaml:"url"`
+ Connected bool `json:"connected" yaml:"connected"`
+ Error string `json:"error" yaml:"error"`
+}
+
+type ManagementStateOutput struct {
+ URL string `json:"url" yaml:"url"`
+ Connected bool `json:"connected" yaml:"connected"`
+ Error string `json:"error" yaml:"error"`
+}
+
+type RelayStateOutputDetail struct {
+ URI string `json:"uri" yaml:"uri"`
+ Available bool `json:"available" yaml:"available"`
+ Error string `json:"error" yaml:"error"`
+}
+
+type RelayStateOutput struct {
+ Total int `json:"total" yaml:"total"`
+ Available int `json:"available" yaml:"available"`
+ Details []RelayStateOutputDetail `json:"details" yaml:"details"`
+}
+
+type IceCandidateType struct {
+ Local string `json:"local" yaml:"local"`
+ Remote string `json:"remote" yaml:"remote"`
+}
+
+type NsServerGroupStateOutput struct {
+ Servers []string `json:"servers" yaml:"servers"`
+ Domains []string `json:"domains" yaml:"domains"`
+ Enabled bool `json:"enabled" yaml:"enabled"`
+ Error string `json:"error" yaml:"error"`
+}
+
+type OutputOverview struct {
+ Peers PeersStateOutput `json:"peers" yaml:"peers"`
+ CliVersion string `json:"cliVersion" yaml:"cliVersion"`
+ DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
+ ManagementState ManagementStateOutput `json:"management" yaml:"management"`
+ SignalState SignalStateOutput `json:"signal" yaml:"signal"`
+ Relays RelayStateOutput `json:"relays" yaml:"relays"`
+ IP string `json:"netbirdIp" yaml:"netbirdIp"`
+ PubKey string `json:"publicKey" yaml:"publicKey"`
+ KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
+ FQDN string `json:"fqdn" yaml:"fqdn"`
+ RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
+ RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
+ Networks []string `json:"networks" yaml:"networks"`
+ NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
+ NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
+ Events []SystemEventOutput `json:"events" yaml:"events"`
+}
+
+func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
+ pbFullStatus := resp.GetFullStatus()
+
+ managementState := pbFullStatus.GetManagementState()
+ managementOverview := ManagementStateOutput{
+ URL: managementState.GetURL(),
+ Connected: managementState.GetConnected(),
+ Error: managementState.Error,
+ }
+
+ signalState := pbFullStatus.GetSignalState()
+ signalOverview := SignalStateOutput{
+ URL: signalState.GetURL(),
+ Connected: signalState.GetConnected(),
+ Error: signalState.Error,
+ }
+
+ relayOverview := mapRelays(pbFullStatus.GetRelays())
+ peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter)
+
+ overview := OutputOverview{
+ Peers: peersOverview,
+ CliVersion: version.NetbirdVersion(),
+ DaemonVersion: resp.GetDaemonVersion(),
+ ManagementState: managementOverview,
+ SignalState: signalOverview,
+ Relays: relayOverview,
+ IP: pbFullStatus.GetLocalPeerState().GetIP(),
+ PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
+ KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
+ FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
+ RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
+ RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
+ Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
+ NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
+ NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
+ Events: mapEvents(pbFullStatus.GetEvents()),
+ }
+
+ if anon {
+ anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
+ anonymizeOverview(anonymizer, &overview)
+ }
+
+ return overview
+}
+
+func mapRelays(relays []*proto.RelayState) RelayStateOutput {
+ var relayStateDetail []RelayStateOutputDetail
+
+ var relaysAvailable int
+ for _, relay := range relays {
+ available := relay.GetAvailable()
+ relayStateDetail = append(relayStateDetail,
+ RelayStateOutputDetail{
+ URI: relay.URI,
+ Available: available,
+ Error: relay.GetError(),
+ },
+ )
+
+ if available {
+ relaysAvailable++
+ }
+ }
+
+ return RelayStateOutput{
+ Total: len(relays),
+ Available: relaysAvailable,
+ Details: relayStateDetail,
+ }
+}
+
+func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
+ mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
+ for _, pbNsGroupServer := range servers {
+ mappedNSGroups = append(mappedNSGroups, NsServerGroupStateOutput{
+ Servers: pbNsGroupServer.GetServers(),
+ Domains: pbNsGroupServer.GetDomains(),
+ Enabled: pbNsGroupServer.GetEnabled(),
+ Error: pbNsGroupServer.GetError(),
+ })
+ }
+ return mappedNSGroups
+}
+
+func mapPeers(
+ peers []*proto.PeerState,
+ statusFilter string,
+ prefixNamesFilter []string,
+ prefixNamesFilterMap map[string]struct{},
+ ipsFilter map[string]struct{},
+) PeersStateOutput {
+ var peersStateDetail []PeerStateDetailOutput
+ peersConnected := 0
+ for _, pbPeerState := range peers {
+ localICE := ""
+ remoteICE := ""
+ localICEEndpoint := ""
+ remoteICEEndpoint := ""
+ relayServerAddress := ""
+ connType := ""
+ lastHandshake := time.Time{}
+ transferReceived := int64(0)
+ transferSent := int64(0)
+
+ isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
+ if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
+ continue
+ }
+ if isPeerConnected {
+ peersConnected++
+
+ localICE = pbPeerState.GetLocalIceCandidateType()
+ remoteICE = pbPeerState.GetRemoteIceCandidateType()
+ localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
+ remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
+ connType = "P2P"
+ if pbPeerState.Relayed {
+ connType = "Relayed"
+ }
+ relayServerAddress = pbPeerState.GetRelayAddress()
+ lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
+ transferReceived = pbPeerState.GetBytesRx()
+ transferSent = pbPeerState.GetBytesTx()
+ }
+
+ timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
+ peerState := PeerStateDetailOutput{
+ IP: pbPeerState.GetIP(),
+ PubKey: pbPeerState.GetPubKey(),
+ Status: pbPeerState.GetConnStatus(),
+ LastStatusUpdate: timeLocal,
+ ConnType: connType,
+ IceCandidateType: IceCandidateType{
+ Local: localICE,
+ Remote: remoteICE,
+ },
+ IceCandidateEndpoint: IceCandidateType{
+ Local: localICEEndpoint,
+ Remote: remoteICEEndpoint,
+ },
+ RelayAddress: relayServerAddress,
+ FQDN: pbPeerState.GetFqdn(),
+ LastWireguardHandshake: lastHandshake,
+ TransferReceived: transferReceived,
+ TransferSent: transferSent,
+ Latency: pbPeerState.GetLatency().AsDuration(),
+ RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
+ Networks: pbPeerState.GetNetworks(),
+ }
+
+ peersStateDetail = append(peersStateDetail, peerState)
+ }
+
+ sortPeersByIP(peersStateDetail)
+
+ peersOverview := PeersStateOutput{
+ Total: len(peersStateDetail),
+ Connected: peersConnected,
+ Details: peersStateDetail,
+ }
+ return peersOverview
+}
+
+func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
+ if len(peersStateDetail) > 0 {
+ sort.SliceStable(peersStateDetail, func(i, j int) bool {
+ iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
+ jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
+ return iAddr.Compare(jAddr) == -1
+ })
+ }
+}
+
+func ParseToJSON(overview OutputOverview) (string, error) {
+ jsonBytes, err := json.Marshal(overview)
+ if err != nil {
+ return "", fmt.Errorf("json marshal failed")
+ }
+ return string(jsonBytes), err
+}
+
+func ParseToYAML(overview OutputOverview) (string, error) {
+ yamlBytes, err := yaml.Marshal(overview)
+ if err != nil {
+ return "", fmt.Errorf("yaml marshal failed")
+ }
+ return string(yamlBytes), nil
+}
+
+func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
+ var managementConnString string
+ if overview.ManagementState.Connected {
+ managementConnString = "Connected"
+ if showURL {
+ managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
+ }
+ } else {
+ managementConnString = "Disconnected"
+ if overview.ManagementState.Error != "" {
+ managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
+ }
+ }
+
+ var signalConnString string
+ if overview.SignalState.Connected {
+ signalConnString = "Connected"
+ if showURL {
+ signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
+ }
+ } else {
+ signalConnString = "Disconnected"
+ if overview.SignalState.Error != "" {
+ signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
+ }
+ }
+
+ interfaceTypeString := "Userspace"
+ interfaceIP := overview.IP
+ if overview.KernelInterface {
+ interfaceTypeString = "Kernel"
+ } else if overview.IP == "" {
+ interfaceTypeString = "N/A"
+ interfaceIP = "N/A"
+ }
+
+ var relaysString string
+ if showRelays {
+ for _, relay := range overview.Relays.Details {
+ available := "Available"
+ reason := ""
+ if !relay.Available {
+ available = "Unavailable"
+ reason = fmt.Sprintf(", reason: %s", relay.Error)
+ }
+ relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
+ }
+ } else {
+ relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
+ }
+
+ networks := "-"
+ if len(overview.Networks) > 0 {
+ sort.Strings(overview.Networks)
+ networks = strings.Join(overview.Networks, ", ")
+ }
+
+ var dnsServersString string
+ if showNameServers {
+ for _, nsServerGroup := range overview.NSServerGroups {
+ enabled := "Available"
+ if !nsServerGroup.Enabled {
+ enabled = "Unavailable"
+ }
+ errorString := ""
+ if nsServerGroup.Error != "" {
+ errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
+ errorString = strings.TrimSpace(errorString)
+ }
+
+ domainsString := strings.Join(nsServerGroup.Domains, ", ")
+ if domainsString == "" {
+ domainsString = "." // Show "." for the default zone
+ }
+ dnsServersString += fmt.Sprintf(
+ "\n [%s] for [%s] is %s%s",
+ strings.Join(nsServerGroup.Servers, ", "),
+ domainsString,
+ enabled,
+ errorString,
+ )
+ }
+ } else {
+ dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
+ }
+
+ rosenpassEnabledStatus := "false"
+ if overview.RosenpassEnabled {
+ rosenpassEnabledStatus = "true"
+ if overview.RosenpassPermissive {
+ rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
+ }
+ }
+
+ peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
+
+ goos := runtime.GOOS
+ goarch := runtime.GOARCH
+ goarm := ""
+ if goarch == "arm" {
+ goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
+ }
+
+ summary := fmt.Sprintf(
+ "OS: %s\n"+
+ "Daemon version: %s\n"+
+ "CLI version: %s\n"+
+ "Management: %s\n"+
+ "Signal: %s\n"+
+ "Relays: %s\n"+
+ "Nameservers: %s\n"+
+ "FQDN: %s\n"+
+ "NetBird IP: %s\n"+
+ "Interface type: %s\n"+
+ "Quantum resistance: %s\n"+
+ "Networks: %s\n"+
+ "Forwarding rules: %d\n"+
+ "Peers count: %s\n",
+ fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
+ overview.DaemonVersion,
+ version.NetbirdVersion(),
+ managementConnString,
+ signalConnString,
+ relaysString,
+ dnsServersString,
+ overview.FQDN,
+ interfaceIP,
+ interfaceTypeString,
+ rosenpassEnabledStatus,
+ networks,
+ overview.NumberOfForwardingRules,
+ peersCountString,
+ )
+ return summary
+}
+
+func ParseToFullDetailSummary(overview OutputOverview) string {
+ parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
+ parsedEventsString := parseEvents(overview.Events)
+ summary := ParseGeneralSummary(overview, true, true, true)
+
+ return fmt.Sprintf(
+ "Peers detail:"+
+ "%s\n"+
+ "Events:"+
+ "%s\n"+
+ "%s",
+ parsedPeersString,
+ parsedEventsString,
+ summary,
+ )
+}
+
+func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
+ var (
+ peersString = ""
+ )
+
+ for _, peerState := range peers.Details {
+
+ localICE := "-"
+ if peerState.IceCandidateType.Local != "" {
+ localICE = peerState.IceCandidateType.Local
+ }
+
+ remoteICE := "-"
+ if peerState.IceCandidateType.Remote != "" {
+ remoteICE = peerState.IceCandidateType.Remote
+ }
+
+ localICEEndpoint := "-"
+ if peerState.IceCandidateEndpoint.Local != "" {
+ localICEEndpoint = peerState.IceCandidateEndpoint.Local
+ }
+
+ remoteICEEndpoint := "-"
+ if peerState.IceCandidateEndpoint.Remote != "" {
+ remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
+ }
+
+ rosenpassEnabledStatus := "false"
+ if rosenpassEnabled {
+ if peerState.RosenpassEnabled {
+ rosenpassEnabledStatus = "true"
+ } else {
+ if rosenpassPermissive {
+ rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
+ } else {
+ rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
+ }
+ }
+ } else {
+ if peerState.RosenpassEnabled {
+ rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
+ }
+ }
+
+ networks := "-"
+ if len(peerState.Networks) > 0 {
+ sort.Strings(peerState.Networks)
+ networks = strings.Join(peerState.Networks, ", ")
+ }
+
+ peerString := fmt.Sprintf(
+ "\n %s:\n"+
+ " NetBird IP: %s\n"+
+ " Public key: %s\n"+
+ " Status: %s\n"+
+ " -- detail --\n"+
+ " Connection type: %s\n"+
+ " ICE candidate (Local/Remote): %s/%s\n"+
+ " ICE candidate endpoints (Local/Remote): %s/%s\n"+
+ " Relay server address: %s\n"+
+ " Last connection update: %s\n"+
+ " Last WireGuard handshake: %s\n"+
+ " Transfer status (received/sent) %s/%s\n"+
+ " Quantum resistance: %s\n"+
+ " Networks: %s\n"+
+ " Latency: %s\n",
+ peerState.FQDN,
+ peerState.IP,
+ peerState.PubKey,
+ peerState.Status,
+ peerState.ConnType,
+ localICE,
+ remoteICE,
+ localICEEndpoint,
+ remoteICEEndpoint,
+ peerState.RelayAddress,
+ timeAgo(peerState.LastStatusUpdate),
+ timeAgo(peerState.LastWireguardHandshake),
+ toIEC(peerState.TransferReceived),
+ toIEC(peerState.TransferSent),
+ rosenpassEnabledStatus,
+ networks,
+ peerState.Latency.String(),
+ )
+
+ peersString += peerString
+ }
+ return peersString
+}
+
+func skipDetailByFilters(
+ peerState *proto.PeerState,
+ isConnected bool,
+ statusFilter string,
+ prefixNamesFilter []string,
+ prefixNamesFilterMap map[string]struct{},
+ ipsFilter map[string]struct{},
+) bool {
+ statusEval := false
+ ipEval := false
+ nameEval := true
+
+ if statusFilter != "" {
+ lowerStatusFilter := strings.ToLower(statusFilter)
+ if lowerStatusFilter == "disconnected" && isConnected {
+ statusEval = true
+ } else if lowerStatusFilter == "connected" && !isConnected {
+ statusEval = true
+ }
+ }
+
+ if len(ipsFilter) > 0 {
+ _, ok := ipsFilter[peerState.IP]
+ if !ok {
+ ipEval = true
+ }
+ }
+
+ if len(prefixNamesFilter) > 0 {
+ for prefixNameFilter := range prefixNamesFilterMap {
+ if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
+ nameEval = false
+ break
+ }
+ }
+ } else {
+ nameEval = false
+ }
+
+ return statusEval || ipEval || nameEval
+}
+
+func toIEC(b int64) string {
+ const unit = 1024
+ if b < unit {
+ return fmt.Sprintf("%d B", b)
+ }
+ div, exp := int64(unit), 0
+ for n := b / unit; n >= unit; n /= unit {
+ div *= unit
+ exp++
+ }
+ return fmt.Sprintf("%.1f %ciB",
+ float64(b)/float64(div), "KMGTPE"[exp])
+}
+
+func countEnabled(dnsServers []NsServerGroupStateOutput) int {
+ count := 0
+ for _, server := range dnsServers {
+ if server.Enabled {
+ count++
+ }
+ }
+ return count
+}
+
+// timeAgo returns a string representing the duration since the provided time in a human-readable format.
+func timeAgo(t time.Time) string {
+ if t.IsZero() || t.Equal(time.Unix(0, 0)) {
+ return "-"
+ }
+ duration := time.Since(t)
+ switch {
+ case duration < time.Second:
+ return "Now"
+ case duration < time.Minute:
+ seconds := int(duration.Seconds())
+ if seconds == 1 {
+ return "1 second ago"
+ }
+ return fmt.Sprintf("%d seconds ago", seconds)
+ case duration < time.Hour:
+ minutes := int(duration.Minutes())
+ seconds := int(duration.Seconds()) % 60
+ if minutes == 1 {
+ if seconds == 1 {
+ return "1 minute, 1 second ago"
+ } else if seconds > 0 {
+ return fmt.Sprintf("1 minute, %d seconds ago", seconds)
+ }
+ return "1 minute ago"
+ }
+ if seconds > 0 {
+ return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
+ }
+ return fmt.Sprintf("%d minutes ago", minutes)
+ case duration < 24*time.Hour:
+ hours := int(duration.Hours())
+ minutes := int(duration.Minutes()) % 60
+ if hours == 1 {
+ if minutes == 1 {
+ return "1 hour, 1 minute ago"
+ } else if minutes > 0 {
+ return fmt.Sprintf("1 hour, %d minutes ago", minutes)
+ }
+ return "1 hour ago"
+ }
+ if minutes > 0 {
+ return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
+ }
+ return fmt.Sprintf("%d hours ago", hours)
+ }
+
+ days := int(duration.Hours()) / 24
+ hours := int(duration.Hours()) % 24
+ if days == 1 {
+ if hours == 1 {
+ return "1 day, 1 hour ago"
+ } else if hours > 0 {
+ return fmt.Sprintf("1 day, %d hours ago", hours)
+ }
+ return "1 day ago"
+ }
+ if hours > 0 {
+ return fmt.Sprintf("%d days, %d hours ago", days, hours)
+ }
+ return fmt.Sprintf("%d days ago", days)
+}
+
+func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
+ peer.FQDN = a.AnonymizeDomain(peer.FQDN)
+ if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
+ peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
+ }
+ if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
+ peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
+ }
+
+ peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
+
+ for i, route := range peer.Networks {
+ peer.Networks[i] = a.AnonymizeIPString(route)
+ }
+
+ for i, route := range peer.Networks {
+ peer.Networks[i] = a.AnonymizeRoute(route)
+ }
+}
+
+func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
+ for i, peer := range overview.Peers.Details {
+ peer := peer
+ anonymizePeerDetail(a, &peer)
+ overview.Peers.Details[i] = peer
+ }
+
+ overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
+ overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
+ overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
+ overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
+
+ overview.IP = a.AnonymizeIPString(overview.IP)
+ for i, detail := range overview.Relays.Details {
+ detail.URI = a.AnonymizeURI(detail.URI)
+ detail.Error = a.AnonymizeString(detail.Error)
+ overview.Relays.Details[i] = detail
+ }
+
+ for i, nsGroup := range overview.NSServerGroups {
+ for j, domain := range nsGroup.Domains {
+ overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
+ }
+ for j, ns := range nsGroup.Servers {
+ host, port, err := net.SplitHostPort(ns)
+ if err == nil {
+ overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
+ }
+ }
+ }
+
+ for i, route := range overview.Networks {
+ overview.Networks[i] = a.AnonymizeRoute(route)
+ }
+
+ overview.FQDN = a.AnonymizeDomain(overview.FQDN)
+
+ for i, event := range overview.Events {
+ overview.Events[i].Message = a.AnonymizeString(event.Message)
+ overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
+
+ for k, v := range event.Metadata {
+ event.Metadata[k] = a.AnonymizeString(v)
+ }
+ }
+}
diff --git a/client/status/status_test.go b/client/status/status_test.go
new file mode 100644
index 000000000..e48b441f5
--- /dev/null
+++ b/client/status/status_test.go
@@ -0,0 +1,607 @@
+package status
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/types/known/durationpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/netbirdio/netbird/client/proto"
+ "github.com/netbirdio/netbird/version"
+)
+
+func init() {
+ loc, err := time.LoadLocation("UTC")
+ if err != nil {
+ panic(err)
+ }
+
+ time.Local = loc
+}
+
+var resp = &proto.StatusResponse{
+ Status: "Connected",
+ FullStatus: &proto.FullStatus{
+ Peers: []*proto.PeerState{
+ {
+ IP: "192.168.178.101",
+ PubKey: "Pubkey1",
+ Fqdn: "peer-1.awesome-domain.com",
+ ConnStatus: "Connected",
+ ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
+ Relayed: false,
+ LocalIceCandidateType: "",
+ RemoteIceCandidateType: "",
+ LocalIceCandidateEndpoint: "",
+ RemoteIceCandidateEndpoint: "",
+ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
+ BytesRx: 200,
+ BytesTx: 100,
+ Networks: []string{
+ "10.1.0.0/24",
+ },
+ Latency: durationpb.New(time.Duration(10000000)),
+ },
+ {
+ IP: "192.168.178.102",
+ PubKey: "Pubkey2",
+ Fqdn: "peer-2.awesome-domain.com",
+ ConnStatus: "Connected",
+ ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
+ Relayed: true,
+ LocalIceCandidateType: "relay",
+ RemoteIceCandidateType: "prflx",
+ LocalIceCandidateEndpoint: "10.0.0.1:10001",
+ RemoteIceCandidateEndpoint: "10.0.10.1:10002",
+ LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
+ BytesRx: 2000,
+ BytesTx: 1000,
+ Latency: durationpb.New(time.Duration(10000000)),
+ },
+ },
+ ManagementState: &proto.ManagementState{
+ URL: "my-awesome-management.com:443",
+ Connected: true,
+ Error: "",
+ },
+ SignalState: &proto.SignalState{
+ URL: "my-awesome-signal.com:443",
+ Connected: true,
+ Error: "",
+ },
+ Relays: []*proto.RelayState{
+ {
+ URI: "stun:my-awesome-stun.com:3478",
+ Available: true,
+ Error: "",
+ },
+ {
+ URI: "turns:my-awesome-turn.com:443?transport=tcp",
+ Available: false,
+ Error: "context: deadline exceeded",
+ },
+ },
+ LocalPeerState: &proto.LocalPeerState{
+ IP: "192.168.178.100/16",
+ PubKey: "Some-Pub-Key",
+ KernelInterface: true,
+ Fqdn: "some-localhost.awesome-domain.com",
+ Networks: []string{
+ "10.10.0.0/24",
+ },
+ },
+ DnsServers: []*proto.NSGroupState{
+ {
+ Servers: []string{
+ "8.8.8.8:53",
+ },
+ Domains: nil,
+ Enabled: true,
+ Error: "",
+ },
+ {
+ Servers: []string{
+ "1.1.1.1:53",
+ "2.2.2.2:53",
+ },
+ Domains: []string{
+ "example.com",
+ "example.net",
+ },
+ Enabled: false,
+ Error: "timeout",
+ },
+ },
+ },
+ DaemonVersion: "0.14.1",
+}
+
+var overview = OutputOverview{
+ Peers: PeersStateOutput{
+ Total: 2,
+ Connected: 2,
+ Details: []PeerStateDetailOutput{
+ {
+ IP: "192.168.178.101",
+ PubKey: "Pubkey1",
+ FQDN: "peer-1.awesome-domain.com",
+ Status: "Connected",
+ LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
+ ConnType: "P2P",
+ IceCandidateType: IceCandidateType{
+ Local: "",
+ Remote: "",
+ },
+ IceCandidateEndpoint: IceCandidateType{
+ Local: "",
+ Remote: "",
+ },
+ LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
+ TransferReceived: 200,
+ TransferSent: 100,
+ Networks: []string{
+ "10.1.0.0/24",
+ },
+ Latency: time.Duration(10000000),
+ },
+ {
+ IP: "192.168.178.102",
+ PubKey: "Pubkey2",
+ FQDN: "peer-2.awesome-domain.com",
+ Status: "Connected",
+ LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
+ ConnType: "Relayed",
+ IceCandidateType: IceCandidateType{
+ Local: "relay",
+ Remote: "prflx",
+ },
+ IceCandidateEndpoint: IceCandidateType{
+ Local: "10.0.0.1:10001",
+ Remote: "10.0.10.1:10002",
+ },
+ LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
+ TransferReceived: 2000,
+ TransferSent: 1000,
+ Latency: time.Duration(10000000),
+ },
+ },
+ },
+ Events: []SystemEventOutput{},
+ CliVersion: version.NetbirdVersion(),
+ DaemonVersion: "0.14.1",
+ ManagementState: ManagementStateOutput{
+ URL: "my-awesome-management.com:443",
+ Connected: true,
+ Error: "",
+ },
+ SignalState: SignalStateOutput{
+ URL: "my-awesome-signal.com:443",
+ Connected: true,
+ Error: "",
+ },
+ Relays: RelayStateOutput{
+ Total: 2,
+ Available: 1,
+ Details: []RelayStateOutputDetail{
+ {
+ URI: "stun:my-awesome-stun.com:3478",
+ Available: true,
+ Error: "",
+ },
+ {
+ URI: "turns:my-awesome-turn.com:443?transport=tcp",
+ Available: false,
+ Error: "context: deadline exceeded",
+ },
+ },
+ },
+ IP: "192.168.178.100/16",
+ PubKey: "Some-Pub-Key",
+ KernelInterface: true,
+ FQDN: "some-localhost.awesome-domain.com",
+ NSServerGroups: []NsServerGroupStateOutput{
+ {
+ Servers: []string{
+ "8.8.8.8:53",
+ },
+ Domains: nil,
+ Enabled: true,
+ Error: "",
+ },
+ {
+ Servers: []string{
+ "1.1.1.1:53",
+ "2.2.2.2:53",
+ },
+ Domains: []string{
+ "example.com",
+ "example.net",
+ },
+ Enabled: false,
+ Error: "timeout",
+ },
+ },
+ Networks: []string{
+ "10.10.0.0/24",
+ },
+}
+
+func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
+ convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil)
+
+ assert.Equal(t, overview, convertedResult)
+}
+
+func TestSortingOfPeers(t *testing.T) {
+ peers := []PeerStateDetailOutput{
+ {
+ IP: "192.168.178.104",
+ },
+ {
+ IP: "192.168.178.102",
+ },
+ {
+ IP: "192.168.178.101",
+ },
+ {
+ IP: "192.168.178.105",
+ },
+ {
+ IP: "192.168.178.103",
+ },
+ }
+
+ sortPeersByIP(peers)
+
+ assert.Equal(t, peers[3].IP, "192.168.178.104")
+}
+
+func TestParsingToJSON(t *testing.T) {
+ jsonString, _ := ParseToJSON(overview)
+
+ //@formatter:off
+ expectedJSONString := `
+ {
+ "peers": {
+ "total": 2,
+ "connected": 2,
+ "details": [
+ {
+ "fqdn": "peer-1.awesome-domain.com",
+ "netbirdIp": "192.168.178.101",
+ "publicKey": "Pubkey1",
+ "status": "Connected",
+ "lastStatusUpdate": "2001-01-01T01:01:01Z",
+ "connectionType": "P2P",
+ "iceCandidateType": {
+ "local": "",
+ "remote": ""
+ },
+ "iceCandidateEndpoint": {
+ "local": "",
+ "remote": ""
+ },
+ "relayAddress": "",
+ "lastWireguardHandshake": "2001-01-01T01:01:02Z",
+ "transferReceived": 200,
+ "transferSent": 100,
+ "latency": 10000000,
+ "quantumResistance": false,
+ "networks": [
+ "10.1.0.0/24"
+ ]
+ },
+ {
+ "fqdn": "peer-2.awesome-domain.com",
+ "netbirdIp": "192.168.178.102",
+ "publicKey": "Pubkey2",
+ "status": "Connected",
+ "lastStatusUpdate": "2002-02-02T02:02:02Z",
+ "connectionType": "Relayed",
+ "iceCandidateType": {
+ "local": "relay",
+ "remote": "prflx"
+ },
+ "iceCandidateEndpoint": {
+ "local": "10.0.0.1:10001",
+ "remote": "10.0.10.1:10002"
+ },
+ "relayAddress": "",
+ "lastWireguardHandshake": "2002-02-02T02:02:03Z",
+ "transferReceived": 2000,
+ "transferSent": 1000,
+ "latency": 10000000,
+ "quantumResistance": false,
+ "networks": null
+ }
+ ]
+ },
+ "cliVersion": "development",
+ "daemonVersion": "0.14.1",
+ "management": {
+ "url": "my-awesome-management.com:443",
+ "connected": true,
+ "error": ""
+ },
+ "signal": {
+ "url": "my-awesome-signal.com:443",
+ "connected": true,
+ "error": ""
+ },
+ "relays": {
+ "total": 2,
+ "available": 1,
+ "details": [
+ {
+ "uri": "stun:my-awesome-stun.com:3478",
+ "available": true,
+ "error": ""
+ },
+ {
+ "uri": "turns:my-awesome-turn.com:443?transport=tcp",
+ "available": false,
+ "error": "context: deadline exceeded"
+ }
+ ]
+ },
+ "netbirdIp": "192.168.178.100/16",
+ "publicKey": "Some-Pub-Key",
+ "usesKernelInterface": true,
+ "fqdn": "some-localhost.awesome-domain.com",
+ "quantumResistance": false,
+ "quantumResistancePermissive": false,
+ "networks": [
+ "10.10.0.0/24"
+ ],
+ "forwardingRules": 0,
+ "dnsServers": [
+ {
+ "servers": [
+ "8.8.8.8:53"
+ ],
+ "domains": null,
+ "enabled": true,
+ "error": ""
+ },
+ {
+ "servers": [
+ "1.1.1.1:53",
+ "2.2.2.2:53"
+ ],
+ "domains": [
+ "example.com",
+ "example.net"
+ ],
+ "enabled": false,
+ "error": "timeout"
+ }
+ ],
+ "events": []
+ }`
+ // @formatter:on
+
+ var expectedJSON bytes.Buffer
+ require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
+
+ assert.Equal(t, expectedJSON.String(), jsonString)
+}
+
+func TestParsingToYAML(t *testing.T) {
+ yaml, _ := ParseToYAML(overview)
+
+ expectedYAML :=
+ `peers:
+ total: 2
+ connected: 2
+ details:
+ - fqdn: peer-1.awesome-domain.com
+ netbirdIp: 192.168.178.101
+ publicKey: Pubkey1
+ status: Connected
+ lastStatusUpdate: 2001-01-01T01:01:01Z
+ connectionType: P2P
+ iceCandidateType:
+ local: ""
+ remote: ""
+ iceCandidateEndpoint:
+ local: ""
+ remote: ""
+ relayAddress: ""
+ lastWireguardHandshake: 2001-01-01T01:01:02Z
+ transferReceived: 200
+ transferSent: 100
+ latency: 10ms
+ quantumResistance: false
+ networks:
+ - 10.1.0.0/24
+ - fqdn: peer-2.awesome-domain.com
+ netbirdIp: 192.168.178.102
+ publicKey: Pubkey2
+ status: Connected
+ lastStatusUpdate: 2002-02-02T02:02:02Z
+ connectionType: Relayed
+ iceCandidateType:
+ local: relay
+ remote: prflx
+ iceCandidateEndpoint:
+ local: 10.0.0.1:10001
+ remote: 10.0.10.1:10002
+ relayAddress: ""
+ lastWireguardHandshake: 2002-02-02T02:02:03Z
+ transferReceived: 2000
+ transferSent: 1000
+ latency: 10ms
+ quantumResistance: false
+ networks: []
+cliVersion: development
+daemonVersion: 0.14.1
+management:
+ url: my-awesome-management.com:443
+ connected: true
+ error: ""
+signal:
+ url: my-awesome-signal.com:443
+ connected: true
+ error: ""
+relays:
+ total: 2
+ available: 1
+ details:
+ - uri: stun:my-awesome-stun.com:3478
+ available: true
+ error: ""
+ - uri: turns:my-awesome-turn.com:443?transport=tcp
+ available: false
+ error: 'context: deadline exceeded'
+netbirdIp: 192.168.178.100/16
+publicKey: Some-Pub-Key
+usesKernelInterface: true
+fqdn: some-localhost.awesome-domain.com
+quantumResistance: false
+quantumResistancePermissive: false
+networks:
+ - 10.10.0.0/24
+forwardingRules: 0
+dnsServers:
+ - servers:
+ - 8.8.8.8:53
+ domains: []
+ enabled: true
+ error: ""
+ - servers:
+ - 1.1.1.1:53
+ - 2.2.2.2:53
+ domains:
+ - example.com
+ - example.net
+ enabled: false
+ error: timeout
+events: []
+`
+
+ assert.Equal(t, expectedYAML, yaml)
+}
+
+func TestParsingToDetail(t *testing.T) {
+ // Calculate time ago based on the fixture dates
+ lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
+ lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
+ lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
+ lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
+
+ detail := ParseToFullDetailSummary(overview)
+
+ expectedDetail := fmt.Sprintf(
+ `Peers detail:
+ peer-1.awesome-domain.com:
+ NetBird IP: 192.168.178.101
+ Public key: Pubkey1
+ Status: Connected
+ -- detail --
+ Connection type: P2P
+ ICE candidate (Local/Remote): -/-
+ ICE candidate endpoints (Local/Remote): -/-
+ Relay server address:
+ Last connection update: %s
+ Last WireGuard handshake: %s
+ Transfer status (received/sent) 200 B/100 B
+ Quantum resistance: false
+ Networks: 10.1.0.0/24
+ Latency: 10ms
+
+ peer-2.awesome-domain.com:
+ NetBird IP: 192.168.178.102
+ Public key: Pubkey2
+ Status: Connected
+ -- detail --
+ Connection type: Relayed
+ ICE candidate (Local/Remote): relay/prflx
+ ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
+ Relay server address:
+ Last connection update: %s
+ Last WireGuard handshake: %s
+ Transfer status (received/sent) 2.0 KiB/1000 B
+ Quantum resistance: false
+ Networks: -
+ Latency: 10ms
+
+Events: No events recorded
+OS: %s/%s
+Daemon version: 0.14.1
+CLI version: %s
+Management: Connected to my-awesome-management.com:443
+Signal: Connected to my-awesome-signal.com:443
+Relays:
+ [stun:my-awesome-stun.com:3478] is Available
+ [turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
+Nameservers:
+ [8.8.8.8:53] for [.] is Available
+ [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
+FQDN: some-localhost.awesome-domain.com
+NetBird IP: 192.168.178.100/16
+Interface type: Kernel
+Quantum resistance: false
+Networks: 10.10.0.0/24
+Forwarding rules: 0
+Peers count: 2/2 Connected
+`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
+
+ assert.Equal(t, expectedDetail, detail)
+}
+
+func TestParsingToShortVersion(t *testing.T) {
+ shortVersion := ParseGeneralSummary(overview, false, false, false)
+
+ expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
+Daemon version: 0.14.1
+CLI version: development
+Management: Connected
+Signal: Connected
+Relays: 1/2 Available
+Nameservers: 1/2 Available
+FQDN: some-localhost.awesome-domain.com
+NetBird IP: 192.168.178.100/16
+Interface type: Kernel
+Quantum resistance: false
+Networks: 10.10.0.0/24
+Forwarding rules: 0
+Peers count: 2/2 Connected
+`
+
+ assert.Equal(t, expectedString, shortVersion)
+}
+
+func TestTimeAgo(t *testing.T) {
+ now := time.Now()
+
+ cases := []struct {
+ name string
+ input time.Time
+ expected string
+ }{
+ {"Now", now, "Now"},
+ {"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
+ {"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
+ {"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
+ {"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
+ {"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
+ {"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
+ {"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
+ {"Zero time", time.Time{}, "-"},
+ {"Unix zero time", time.Unix(0, 0), "-"},
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ result := timeAgo(tc.input)
+ assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
+ })
+ }
+}
diff --git a/client/system/info.go b/client/system/info.go
index d83e9509a..2a0343ca6 100644
--- a/client/system/info.go
+++ b/client/system/info.go
@@ -9,7 +9,6 @@ import (
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/management/proto"
- "github.com/netbirdio/netbird/version"
)
// DeviceNameCtxKey context key for device name
@@ -119,11 +118,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
-// GetDesktopUIUserAgent returns the Desktop ui user agent
-func GetDesktopUIUserAgent() string {
- return "netbird-desktop-ui/" + version.NetbirdVersion()
-}
-
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
diff --git a/client/ui/bundled.go b/client/ui/bundled.go
deleted file mode 100644
index e2c138b14..000000000
--- a/client/ui/bundled.go
+++ /dev/null
@@ -1,12 +0,0 @@
-// auto-generated
-// Code generated by '$ fyne bundle'. DO NOT EDIT.
-
-package main
-
-import "fyne.io/fyne/v2"
-
-var resourceNetbirdSystemtrayConnectedPng = &fyne.StaticResource{
- StaticName: "netbird-systemtray-connected.png",
- StaticContent: []byte(
- "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\b\x06\x00\x00\x00\\r\xa8f\x00\x00\x00\xc3zTXtRaw profile type exif\x00\x00x\xdamP\xdb\r\xc3 \f\xfc\xf7\x14\x1d\xc1\xaf\x80\x19\x874\xa9\xd4\r:~\r8Q\x88r\x92χ\x9d\x1cư\xff\xbe\x1fx50)\xe8\x92-\x95\x94СE\vW\x17\x86\x03\xb53\xa1v>@\xc1S\x1dNɞų\x8c\x86\xa5\xf8\xeb\xa8\xd3d\x83T]-\x17#{Gc\x9d\x1bEGf\xbb\x19\xc5E\xd2&b\x17[\x18\x950\x12\x1e\r\n\x83:\x9e\x85\xa9X\xbe>a\xddq\x86\x8d\x80F\x92\xbb\xf7ir?k\xf6\xedm\x8b\x17\x85y\x17\x12t\x16\xd11\x80\xb4P\x90\xdaE\xf5\xf0\xa1\xfc#u-\x92:[L\xe2\vy\xda\xd3\x01\xf8\x03\xda\xd4Y\x17ݮ\xb7\xee\x00\x00\x01\x84iCCPICC profile\x00\x00x\x9c}\x91=H\xc3@\x1c\xc5_S\xa5\"-\x0e\x16\x14\x11\xccP\x9d\xec\xa2\"\xe2T\xabP\x84\n\xa5Vh\xd5\xc1\xe4\xd2/hҐ\xa4\xb88\n\xae\x05\a?\x16\xab\x0e.κ:\xb8\n\x82\xe0\a\x88\xb3\x83\x93\xa2\x8b\x94\xf8\xbf\xa4\xd0\"ƃ\xe3~\xbc\xbb\xf7\xb8{\a\b\x8d\nSͮ\x18\xa0j\x96\x91N\xc4\xc5lnU\f\xbcB\xc0\x00B\x18\xc1\xac\xc4L}.\x95J\xc2s|\xdd\xc3\xc7\u05fb(\xcf\xf2>\xf7\xe7\b)y\x93\x01>\x918\xc6t\xc3\"\xde \x9e\u07b4t\xce\xfb\xc4aV\x92\x14\xe2s\xe2q\x83.H\xfc\xc8u\xd9\xe57\xceE\x87\x05\x9e\x1962\xe9y\xe20\xb1X\xec`\xb9\x83Y\xc9P\x89\xa7\x88#\x8a\xaaQ\xbe\x90uY\xe1\xbc\xc5Y\xad\xd4X\xeb\x9e\xfc\x85\xc1\xbc\xb6\xb2\xccu\x9a\xc3H`\x11KHA\x84\x8c\x1aʨ\xc0B\x94V\x8d\x14\x13iڏ{\xf8\x87\x1c\x7f\x8a\\2\xb9\xca`\xe4X@\x15*$\xc7\x0f\xfe\a\xbf\xbb5\v\x93\x13nR0\x0et\xbf\xd8\xf6\xc7(\x10\xd8\x05\x9au\xdb\xfe>\xb6\xed\xe6\t\xe0\x7f\x06\xae\xb4\xb6\xbf\xda\x00f>I\xaf\xb7\xb5\xc8\x11з\r\\\\\xb75y\x0f\xb8\xdc\x01\x06\x9ftɐ\x1c\xc9OS(\x14\x80\xf73\xfa\xa6\x1c\xd0\x7f\v\xf4\xae\xb9\xbd\xb5\xf6q\xfa\x00d\xa8\xab\xe4\rpp\b\x8c\x15){\xdd\xe3\xdd=\x9d\xbd\xfd{\xa6\xd5\xdf\x0fںr\xd0VwQ\xba\x00\x00\rxiTXtXML:com.adobe.xmp\x00\x00\x00\x00\x00\n
\n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\xf0C\xff\xd9\x00\x00\x00\x06bKGD\x00\xff\x00\xff\x00\xff\xa0\xbd\xa7\x93\x00\x00\x00\tpHYs\x00\x00\v\x13\x00\x00\v\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\atIME\a\xe8\x02\x17\r$'\xdd\xf7ȗ\x00\x00\x13;IDATx\xda\xed\x9d]o\x14W\x9a\xc7\xff\xa7\xaamh\xbf\xc46,I`\x99\xa1\xc3\ni\xb5{1\x95O0\xe4\x1b\xc0'X\xf2\t`.W`hp\xa2\xb9\fH{O\xa3\xcc\xc5\xecJ3q\xa4\x1d\xed\xcdJx>Aj/\"EBJګL \xb1\x00g\xf1\v\xb6\xbb\xeb\xec\x85mb\f\xb6\xfb\xa5^Ω\xfa\xfd\xee\x928v\xf7\xa9z\xfe\xcfs\x9e\xa7ο$\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\u0603a\t\xc0g\xd6\x7f\x1f5\x92\x8e\"k\xd4\b\xa4s\xb2jH\x9afez\n\xfe\xdb\b\x00x\x81mF\xd3/CE]\xa3(\x94~c\xa5\x8b;\xc1\x0e\x83E\x7f{\xecF\xfcA\x8d\x95\x00g\xb3\xfb\\tQ\xd2o\xadtq]\xba(I\x81\x95,K345\xa3˒\x84\x00\x80SY~5Х0\xd0o\x13\xabK\x96R>\x9b\xe4o\xd4\x1a\xbd\x1e\xc7\b\x008\x93\xe9\xadtkM\x8a\x02i\xdaZ\x9aS\x99\x12\xea\xf6\xabJ\x80Հ\x02\xf7\xf4W\x13\xe9\xdan\xa6'\xe8sXw\xe9\xf6ؿ\xc6\xed_Z\x01\x00\x05d{\xed\xec\xe9!\xcf\xda\x7f\xbb\xf1\xf7Z/\x80U\x81<\x03\xdf\x12\xf8\xc5\xc5\x7f\xf2K\xe9O\x05\x00d\xfcje\xffx\xecF\xfc\xe1\xfe\x7fM\x05\x00\xd9\x04~3j$5ݲVWX\r\a\xe2?\xdc\x1e\xfb!\x00\x909ks\xd1\xd5Dj\x1a\xcb\x18ω\xe07j\xd5\xf74\xfe\x10\x00Ȅ\x95O\xa3(Ht_R\xc4\xdeҙҿ\xbdw췟\x80\x15\x82\xb4\xb2~\x90\xe8+I\x11\xab\xe1\x0e\xd6\xea\xc1A\xd9\x7f[\x1f\x00\x86\xdc\xeb\xdbP\xf7E\x93\xcf\xc9\xec\xbf\x7f\xec\xc7\x16\x00\xd2\v\xfe\xb9\xe8b\"}axd\xd7\xcd\xf8O\x0e.\xfd\xd9\x02\xc0\xb0\xc1\x7f\xcbJ\x0f\t~G\t4_\xbf\x19\xb7\x8e\xfa1*\x00\xe8\x9b\xd5O\xa2\xfb\x8c\xf7\x1c\xcf\xfe\x81~\xd7\xcb\xcf!\x00\xd03\xb6\x19M\xaf\x87\xfaB\x96\xfd\xbe\xd3\xc1\x7f\xc8؏-\x00\fV\xf27\xa3\xc6z\xa8\x87\xa2\xd9\xe7x\xf4\x1f>\xf6\xa3\x02\x80\x81\x82\xdf\xd6\xf4\x10\a\x1e\x0f\xe2?\xd1\xed\xfa\x8d\u07b2\xff\xb6^\x00\x10\xfc\xa5\xc9\xfeG\x8d\xfdJW\x01\xd8f4\xfd\xf2ؾNt\xe7\xe8\x9b5\f\xb4\xdc\r\xb4\xbc\xfb\xcf\xc77\xb4l\x9a\xf12wѾ=?\xc1\xef\xcf\xf5\xb2\xbd5\xfe\x9c\xac\x00\xd6\x7f\x1f5\xc2D\xd3[\x89\x1a\xd6jZ\x81\xa6\x8d\xd5t`tn\xe7\xcb5$M\xcb\xec\x04{\x867\xa5\x95\x96\x8dѲ\xacvK\xa9ec\xb4\x9cX-Z\xa3ec\xd5\x0e\xa4e\xd5\xd4\xee\xb5\xd9\xe2#ks\x11O\xf6\xf9\x92\xfc\x8dZ\xf5\x1b\xf1\xc7N\n\x80mF\xd3[#jlv\x15)\xd0\xf4\x1e\xfb憌\xa6\xbd\xcf0F\xed\x1d\xb1X\xb6\xd2\xffX\xabvh\xd4>\xdeU\xeckU\xb1v'\xfaLF\xd7\b-On\xc1\x9a>\x18$\x19\x99,\x82<\b\xf4\x1b\xb3\xed\xed\x16Y\xa9Q\xe5\x87E\xac\xb4l\xa4XFq\"-\x86V\xb1\xeb°\xf3\x90O\x93\xb0\xf2\xe6\x1e\xbb=>\x1b\x0ft\xbd\xcc\xc0\x81nuq'\x93Gv\xfb\xf4\x17O\x84\xf5G,\xa9\x9d\x18\xfd5\xb4\x8a\xeb\xb3\xf1\x82\v\x1fju.\xbad\xa4/\xb8<\xfeT\x9f\xf5\x8e>\x1c4\xa1\x98\xa3\x82}-\xd4Ek\xd4\xe0e\f\xb9\xb0 \xa3\xd8Z\xfdu\xac\xab\x85\xbc\xab\x04:\xfe\x1eƿ\xd5ǽ<\xf2ۓ\x00l~\x1aE\x9bV\x17\tv\x87\xaa\x04\xa3\x05c\xf5e\x1e\x15\xc2\xda'\xd1w\\s\xbf\xb2\x7f\xbfc\xbf7~\xc5\xda\\tU\xd2%\xcax/z\t\v\x89\u0557\xe1\x88\x16Ҟ>\xb0\xef\xf70\xfe\al\xfc\xbd\xf6;V>\x8d\"\x93p\xaa\xcb\xc7\xedBb\xf5 \r1\xd89\xd3\xff\x1dK\xeaQ\xf0\x0f8\xf6{\xeb\x16`ǹ\xf5!\xcbZM1X\x9d\x8b\xbe2\xcc\xfb\xbd\xaa\x06\x83\x9a>L\xa3\n|\xd5\x03X\xbf\x13]\xb1F\xf7Y^\uf677҃\xf1\xd9x\xbe\x97\x1f^\xb9\x13]\t\xb8\xee\xbe\t\xc0\xc0c\xbf\x03\x05\x00\x11([\x8d\xa8\xb6\x12͛\x11\xdd;,S\xd0\xf8\xf3\xef\xba\x0e\xdb\xf8\xdb\xcbkǁ\xeb7\xe3\x96U\xefG\t\xc1\xe94ѐ\xd15\xdb\xd1w\xab\x9fD\xf7w^\xb5\xfd\xfa\xde\x7f.\xbaE\xf0{\x16\xffI\xba\xf1i\x0e\xd8\x136\xcd\xf6\xdb\\\xa0d\xd9#It{\xe2fܢ\xf1\xe7\xe5\xf5[\x18\xbb\x11\x7f\x94\xb9\x00 \x02\x15\x10\x82\xe7\x13\xed`z\xe5\"\x8b\xe1\xd1eKa\xec׳\x00 \x02%\xde\x1dlִ\xf5\xf5\xafeF;\nO?Wp\xe2\x05\x8b\xe2z\xf0\xa74\xf6;\xb4\a\xb0\x9f\xf1ٸ\x99H\x0fX\xfer\xd1yt\xe6\x95\x10t\x16Oi\xeb\xeb_+y6\xc9\xc28\\\xb1\xf5c\xf3\x95\x9a\x00H\xd2\xc4l|\x05\x11(\x0fɳI\xd9\xcd\xda\x1b\x15\xc1\xae\x10ؕ:\x8b\xe4\x1e\xf7\xb2\xf2\x9d\xe8\xf94\xe0\xfa\\\xf4\x90w\xbb{^\xfaw\x03u\xbe9\xfb\x86\x00\xbc\x91\x15N\xac(<\xfd\\ft\x8bEs \xfb\xa79\xf6\xeb\xbb\x02\xd8\xe5eW\x97\xed\xf6\x11V\xf05\xfb\xff4ud\xf0oW\t\x13\xda\xfa\xfaW\xea>\x9eaъ\x8e\xff$۱|_~\x00ϛ\xd1\xf4h\xa8\x87<6\xeaa\xf6\xdfi\xfc\xf5}\x83\x8cvT\xbb\xf0\x98j\xa0\x88\xe0Ϩ\xf17P\x05 I3\xcdx9\xe8게\xda\\\x1e\xbf\x184\x9bo\vǯ\xd4\xfd\xfe\xa4l\x97\xd7H\xe4J\x98\xfdCy}_\xd1z3n\x9b\x8e>B\x04<\xca\xfe+\xf5\xa1\xbb\xfcݥ\xa9\x9d\xfe\xc1\b\v\x9a\xc75\x93n\xe7a8;\x90\xa4#\x02~\xd1Y<\x95\xe26\x82\xde@\xf6\xb5\xbf\xdaAM\xad<\xfe\xd4\xc05\x1d\"\xe0\ao\x1b\xfb\r\xbd\x9dx2\xa3-\xaa\x81\xec\xe2?\xc9'\xfbok͐`(\xe2p\x19\xb9YS\xe7љ\xd4\x05\xe0\xd5\xcd3\xdaQx\xf6\xa9\x82\xa9U\x16;\xc5\xec\x9f\xe5\xd8/\xb5\n`\x97\x89\xebql\x03}d%ު\xe3\x18\xdd\xc73\x99\x05\xff+\x81\xf9\xf6=\xb6\x04)R3\xba\x9c\xe7\xdfK\xa5\xad;q=\x8e%}\xcc\xe5s+\xfb\xe7\xf5xo\xf7Ɍ:\x8b\xef2%\x186\xf9\x1b\xb5F\xb7c\xc9/\x01\x90\xa4\xf1\xd9x\xdeXD\xc0\xa5\xec\x9fo\xafa\x82)\xc1\xb0\x84\xf9{q\xa4*\xd9\xf5\x9bqK\xa6\xff\x17\x14B\xda\xc18Y\xc8\xe1\x9e\xed\x9e\xc3iD`\x90\xb5S~\x8d\xbf\xd7[\x0e\x19\xc01\xe2b\xd9\xfa\xfaי\xee\xfd\x8f\xced\x89F.<\x96\xa9op1z\x8b\xc2\\\x1b\x7f\x99U\x00{\xb6\x03M\xac\xc5\n*\xfd{|\xde?\xdb\x0f\x11h\xeb\xd1i%?\x8fsAz\x89\xff\xa4\xb8X\xc9\xf4\xed\xc0T\x02E\x94\xe0g\x8a\x17\x80\xbd\xc5\xc0\xb9%\x85\x18\x8e\x1c\x16\x81\xf1؍\xf8â\xfe|\xa6m\xdb\x1dC\x91{\\\xe5\x9c\x12o\xc6c\xbf\x81>\xd3\xe2)u1\x1b98\xfe\xc3|\xc7~\xb9\n\x80$M\xcc\xc6\xd70\x14\xc9'\xfb\xbb\xea\xea\x83\b\x1c\x10\xfcF\xad\"\x1a\x7f\xb9\n\xc0\x8e\b\xe0*\x941\x9do\xdfw\xbb:A\x04\xf6\x97\xfe\xed\"\xc6~\x85\b\x80$muu\rC\x91lH\x9eMʮ\x8f\xba\xbfE\xf9\xfe\xa4\xec\xfa1.\x98$k\xf5\xa0\xe8쿭C9\x82\xa1HF\xe2Z\xf4د\x1f\xc2D#\xff\xf8\xb7j\x1b\x8c\x148\xf6+\xac\x02\x90\xb6\rE6\xbb\x9c L5\xab:\xd8\xf8;\xfc\x03\a\x95\x7fX\xa8ȱ_\xa1\x02\xb0+\x02\x1c#N\xa9\x8cܬ\xa9\xfbd\xc6\xcb\xcf\xdd\xf9\xf6\xbdJ\x9e\x1d0F\xad\xfa\u0378UY\x01\x90\xf0\x12H3\xfb{+^\xeb\xa3\xea~\xffwջh\xa1[\x0f\xc8\x15&\xc1\x88\xc0\xb0\x01t\xcc\xfb\x97y$\xcf&*u\x94\u0605\xb1\x9f3\x02\xb0W\x04\xf0\x12\xe8\x9fη\uf563\x8ay2S\x8d\x97\x9182\xf6sJ\x00vE\x00C\x91~3\xe7\xa4_\x8d\xbf#\xd8\xfa\xf6\xbd\xd27\x05\xf3\xb4\xf9\xeaO\x97\x1c\x01k\xb1\x1eK\x7f\a\x9f\xf7O\xe5F\x9cx\xa9\x91\v?\x946\xfb\xbb2\xf6s\xae\x02\xd8e\xe2z\x1c\a\x16/\x81#\xb3\xff\xd3\xc9\xd2\x05\xbf$ٕ\xe3\xea.M\x953\xfe\x1d6\xcaqj\x0eS\xbf\x19\xb7p\x15:<\xfb\xfb8\xf6\xeb\xb9\x1f\xf0\xfd\xc9\xd2\xf5\x03\x8cQ\xab>\x1b/ \x00}\x88\x00\xaeB\a\x04H\x05:\xe6\x9d\xc5S\xe5z> t\xdb\x17\xc3ɕ\x1e\xbb\x11\xdf\xc5Pd_\xf6\xffy\xdc\xfb\xb1_\xafUNR\x12\xa1+\xca\xe6\xcb{\x01\x90p\x15z#3~\x7f\xb2:\x95\xceҔ\xff[\x01\xa3\xf6XWw]\xff\x98N\xd7Z\x88\xc06e\x1b\xfbUA\xf0L\xa2ۦ\x19;?\xda6>,\xe6\xca\\t7\x90\xaeV\xb2\xf4/\xe9د\xa7\xed\xf3٧\nO\xfd\xecg\xf6wt\xec\xe7U\x05\xb0K\x95]\x85\xbc;\xed\x97\xf6w\xf7\xb0!hB}\xe4\xcbg\xf5fu\xab\xe8*\xe4\xb2\xcdW>\n\x10xw`\xc8\xc5\xe7\xfdK!\x00R\xf5\\\x85*yZn\x1fɳ\to\x1a\x82VZv}\xec\xe7\xb5\x00\xec\x1a\x8aTA\x04\x92g\x93J~\x1e\x13H\x1d\x7fƂ\xf7|\xca\xfe\x92'M\xc0\xfd\xac7\xa3\x86\xad顬\x1ae\xbd齲\xf9ʁ\x91\v\x8fe&\xd6]\x8e$o\x1a\x7f\xdeV\x00\xbb\x94\xddK\xa0ʍ?_\xab\x00\x97l\xbeJ/\x00e\x16\x01\xbbYSR\xd2C1C\xad\xcb\xcaqw{\x01\x81\xe6]\xb2\xf9\xaa\x84\x00\x94U\x04|\x1d}U\xb9\n0\x81\xbfgW\xbc\xbf\xd3\xca\xe4*T\xf9\xb1\x9f\x87U\x80oc\xbf\xd2\t\xc0\xae\b\x94\xc1U\xa8\xf3\xe8\fQ\xeeS\x15\xe0\xa8\xcdW\xe5\x04@\xda1\x14Q\xb1/Z\x1c\x86*>\xef\xef{\x15\xe0\xaa\xcdW%\x05@\x92\xea\xb3\U000423c6\"\xb6\x1bT\xca\x1d\xb7\x14U\x80Q\xdb\xd7\xc6_i\x05@\xf2\xd3U(\xf9i\x8a\xec\xdfo\x15P\xb0\x89\xa8-\x89}])\xdb\xcd\xf5\x9bq˗c\xc4e\xb7\xf9\xcan\xcb4Q\\\xf27j\x8d\xcf\xc6\xf3\b\x80\xc3\xf8\xe2%@\xe9?\xe0\xba\xfd4Uܸ4,\x8fGE\xa9\aή\x8b\x80]\xa93\xf6\x1bX\x01\x82B\xd6\xce\a\x9b/\x04\xc0\x13\x11\xe8,\x9e\"\x90\x87\xd9\x06,\x8f\xe7\\\xfb\xab\x1d\xd4\xd4*\xd3\x1aV⑳\xf1ٸ隗\x00c\xbf4*\xa8\xe3\xb2\xeb\xc7\xf2\x8b\xff\xa4\\ٿ2\x02 \xb9e(b7k\xec\xfd\xd3\x12Ҽ\x8eL\x97d\xecWY\x01\xd8\x15\x01#-\x14~\xd32\xf6K\xaf\x15\xf0S>\a\xa7j\xc6߇\xcc\x10\x80=\xbc\xec\xear\x91\x86\"v\xb3V\xdaW`\x15\xa3\x00A\xe6O\x06\x1a\xa3\xd6\xe8\xf58F\x00J@ѮB\x94\xfe\x19\xaci\xd6ۀ\xb0\xbc\xd6\xf4\x95
@X\xbd\xec\x8f\x00d \x02\x8c\xfd\x1c\xe8\x03\xf4)\x02I\xc5\x1a\x7f\b\xc0\x80\"Ћ\xa1\bo\xf7q\xa1\x0f\xd0\xc76\xc0\xa8\x1d\xd6t\xb7\xaak\x85\x00\xf4\xc1Q\xaeB\xd8|\xb9\xb2\r\xe8\xdd&\xac\x8c6_\xfd`\xb8]\xfagu.\xfa\xcaH\xd1k7]7P盳\b\x80\v7\xf5hG#\xff\xfc\xbf=e\xff\xb1\x1b\xf1\aU^+*\x80\x01x\x9b\xa1\b6_\x0eU\x00\x9b\xb5\x9e\x9a\xb0>\xbeF\x0e\x01p\x80]W\xa1\xdd\x13\x84\xbc\xdd\xc7A\x8e\xd8\x06\x18\xa3V}6^@\x00``\x11\xd8=F\xcc\x13\x7f\xee\x91\x1c5\t\xa8\xe8\xd8\x0f\x01H\x91z3n'\x1b\xb5{\xae\xbc\xae\x1a\xf6l\x03\x0e\xa9\x00\xcan\xf3\x85\x00乀\xc7:Wk\x17~\x90\u0084\xc5pI\x00\xd6\x0e\xa8\x00\x8c\xdac\xdd\xea\x8e\xfd\x10\x80\x14Y\x9b\x8b\xaeʪaF;\x1a\xb9\xf0\x18\x11pI\x00\x0ehȚD\xb7M3^f\x85\x10\x80\xa1XoF\r\x19]{uc\xd574r\xfeG\x16\xc6\x15\xba\xc1\x9b\x0eA%}\xbb\x0f\x02P\x00IM\xb7d\xd5x\xed\xfe\x9aXWxn\x89\xc5q\xa6\x0f\xf0\xfa6\xc0\x84\xfa\x88UA\x00R\xc9\xfe\xc6\xea\xca\xdb\xfe[x\xe2\x05\"\xe0\xe06\xa0J6_\b@\xd6\xd9?\xd4\x17\x87\xfd\xf7\xf0\xc4\v\x85\xa7\x9f\xb3P\x8e\b\x80\x95\x96\x19\xfb!\x00\xa9\xb0r'\xba\xb2\xff1්\xc0\xfb\xcf\x11\x01w*\x80{d\x7f\x04 \x9d\x05\vt\xabןE\x04\nf\xedX%\xde\xee\x83\x00\xe4\xb5\xf7\x9f\x8b\xdeh\xfc!\x02\x0eW\x00ݠ\x926_\xfd\xc0i\xc0^\x83\xbf\x195l\xa8\xef\x06\xfd\xff;\x8b\xa70\n\xc9\xfb\xe6~g\xbd=\xd5\xfa\xaf\x0fX\t*\x80\xa1Ij\xbd\x97\xfeo\xa3vnI\xc1\x89\x17,d\x8e\x84'V\xc8\xfeT\x00\xc5g\xff\xbdl=:û\x02\xf2\xc8lc\x9b\xf3\xef\xfc\xe1?/\xb3\x12T\x00\xc3\xef%kz\x98\xd6\xef\x1a9\xffD\xa6\xbeɢf\x99\xd5F;\xeav&~\xc7J \x00C\xb3r'\xba\xd2o\xe3\xef\xf0\xba4\xd1ȅ\x1f\x10\x81,K\xff\xa9\xf5\xd6\xcc\x1f\xff\xd8f%\x10\x80\xa1K\xff~\xc6~\xfd\x88@\xed\xfc\x13\x99\xd1\x0e\x8b\x9c>\xed\xb0\xb1\xc4\xde\x1f\x01H#P\xf5/\xa9f\xff}ej\xed\xc2\x0f\x88@\xda7\xf4\xf4\x1ag\xfd\xfb\xb9\x0fY\x82\x83\xb3\x7fZ\x8d\xbfC\xfb\v\x9b5u\x1e\x9d\xc1O0\x8d\x9b\x99\xb1\x1f\x15@Z\f;\xf6\xa3\x12(\xe0f\x9e\xd8\xfa\x98U@\x00\x86fu.\xbat\xd0i\xbf\xccD\xe0\xfc\x8f\x18\x8a\fs#\x8fo\xb4&\xff\xed\xbf\x17X\t\x04`\xf8\x804\xfa,\xf7\xbfY\xdf\xc0Uh\b\x01\x1d9\xb9J\xe3\x0f\x01\x18\x9e\xd4\xc7~}\x8a@\r/\x81\xfe\x19\xedܮ\xdf]h\xb3\x10\b\xc0Pd6\xf6\xeb\xe7\x82L\xadb(\xd2\x1f\xedw\xce\xff\x80\xc9'\x020\xff\v\x8d?*\x80\xe1qe\xec\xd7\x0f\xb5\xb3O+\xed%\x10\x9e\xfa?J\x7f*\x80\x94\xb2\xbfcc\xbf\x9e\xe9\x06\xdb\xd6b\xeb\xa3\xd5\xcaV#\xc9\xfc;\xff>\x8f\xcd\x17\x15@\n\xd9\xff\x88\xb7\xfb\xb8\x9d\x06w\\\x85*t\x82Ќv\xd45DZ\xf9B\x00\x86\xa7\u05f7\xfb\xb8.\x02U:F\x8c\xcd\x17\x02\x90ޗv|\xec\xd7OV\xac\x88\b`\xf3\x85\x00\xa4\xb4\xf7\xf7d\xec\x87\b\xfc\x82\x95\xb0\xf9\xca\xea\xfe\xa9T\xf0\xfb\xdc\xf8;*H6k\xda\xfa\xe6\xac\xd4-\x97\xa6\xf3\xbc?\x15@j\xe4e\xf3UT%PFC\x11l\xbe\xa8\x00Ra\xe5\xd3(\n\x12}U\xf6\xefi\u05cfi\xeb\xd1\xe9RT\x02\xc1\xf8F\xeb\x9d\xcf\xff\x82\x00P\x01\xa4\xf0E\xad\xc7c\xbf~\x14\xbd\xbeQ\n/\x013\xdaQwk\x92\xc6\x1f\x02\x90B\xf6/\xd0\xe6\xab\b\xc2\x13/\xfcw\x15\x1a\xed\xdcf\xec\x87\x00\f\x8d\v6_\x85\x89\x80\xbf^\x02\xd8|!\x00)\xed\x89\x03]\xadR\xf6\x7fM\x04<5\x14\t\xde_\xc6\xe6+\xaf\xadVٳ\x7fY\xc7~\xfd\xe0\x93\xb5\x18c?*\x80\xd4(\xf3د\xac\x95\x006_\b@*\xac܉\xae\xe4\xf9v\x1f/D\xc0qW\xa1`|\x83\xe7\xfd\x11\x80\x94\xbeX@\xf6\x7fC\x04\xdcv\x15\xc2\xe6\v\x01H\x87\xb5\xb9\xa8\xb2\x8d\xbf\xa3\xa8\x9d[\x92\x99x\xe9\xde\xde\x1f\x9b\xafbֽl_\xc8G\x9b\xaf\xdcq\xccP\xc4Ԓ\xf6\xd4\x7f\xcc\xd3\xf8\xa3\x02\x18\x1e\x1fm\xbe\xf2\xdf\v\xec\x18\x8a8b-\x16\x9e}J\xe9O\x05\x90R\xf6g\xec\xd73v\xb3\xa6Σ3\xb2\x9bŽ\x1e\x02\x9b/*\x80\x14S\x89\xeesI\xfbP\xff\x82\xbd\x04\xb0\xf9B\x00Rc\xe5Nt\xc5J\x17\xb9\xa4\xfe\x88\x80\x19\xe92\xf6C\x00R\xfa\"\x8c\xfd|\x13\x01\xc6~\b@J{\xff\x92\xd9|\x15&\x02\xe7\x7f\xcc\xcdP\x04\x9b/G\xae\xbb\xf7\xc1ߌ\x1aI\xa8\xaf\x8c4\xcd\xe5L!0s0\x14\xe1\xed>T\x00\xa9\x91\xd4t\x8b\xe0O18\xeb\x1b\x1a9\xffc\xb67\xdd\xd4\x06.?T\x00\xe9d\x7f\xc6~\xd9\xd0}6\xa9\xee\xe2\xa9\xf4o8cZS\x7f\xfa\x13\x02@\x05\x90B\xb9Z\xd3C.a6d\xe1*dF;JFFh\xfc!\x00\xc3S5\x9b\xaf\xc2D \xcdc\xc4\xd8|!\x00\xa9d\xfef4\xcd\xd8/'\x11H\xcfK\x00\x9b/\x04 \x1d^\x86ⴟg\"\x10\x1c\xdf\xc2\xe6\xcbA\xbck\x02\xd2\xf8+\x8eA\xadŰ\xf9\xa2\x02H\rl\xbe\x8a\xad\x04\x061\x14\xc1\xe6\v\x01H\x85չ\xe8\x126_\xc5R;\xb7ԗ\b`\xf3\x85\x00\xa4\xb7_1\xfa\x8cK\xe6\x86\b\xf4\xe4%\x10&\xcb#'W\x19\xfb!\x00\xc3\xc3\xd8\xcf-z1\x141\xf5\xcd{\xf5\xbb\vd\x7f\x97\x93\xaa\x0f\x1f\x12\x9b/G\xe9\x06\xda\xfa\xe6\xecA\x86\"\xed\xe9?\xff\x99\xc6\x1f\x15\xc0\xf0`\xf3\xe5(ar\xe01\xe2Zc\x89ҟ\n \xa5\xec\xcf\xd8\xcfi\xf6[\x8ba\xf3E\x05\x90\xde\xcd\x15\xd2\xf8s>\x8b\xec1\x141a\x82͗G\xd4\\\xfep\xabs\xd1%\x19E\x92\xda\\*\xc7E\xe0XG#\xff\xf0X\x1b\x8b\xa7\xbe\x9c\xf9\xc3<\xd7\v\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\xe4\xff\x01\xf6P(\xf3)+S\x1f\x00\x00\x00\x00IEND\xaeB`\x82"),
-}
diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go
index 9ed40b0be..51eec59a5 100644
--- a/client/ui/client_ui.go
+++ b/client/ui/client_ui.go
@@ -33,7 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
- "github.com/netbirdio/netbird/client/system"
+ "github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -148,22 +148,24 @@ type serviceClient struct {
icError []byte
// systray menu items
- mStatus *systray.MenuItem
- mUp *systray.MenuItem
- mDown *systray.MenuItem
- mAdminPanel *systray.MenuItem
- mSettings *systray.MenuItem
- mAbout *systray.MenuItem
- mVersionUI *systray.MenuItem
- mVersionDaemon *systray.MenuItem
- mUpdate *systray.MenuItem
- mQuit *systray.MenuItem
- mRoutes *systray.MenuItem
- mAllowSSH *systray.MenuItem
- mAutoConnect *systray.MenuItem
- mEnableRosenpass *systray.MenuItem
- mNotifications *systray.MenuItem
- mAdvancedSettings *systray.MenuItem
+ mStatus *systray.MenuItem
+ mUp *systray.MenuItem
+ mDown *systray.MenuItem
+ mAdminPanel *systray.MenuItem
+ mSettings *systray.MenuItem
+ mAbout *systray.MenuItem
+ mVersionUI *systray.MenuItem
+ mVersionDaemon *systray.MenuItem
+ mUpdate *systray.MenuItem
+ mQuit *systray.MenuItem
+ mNetworks *systray.MenuItem
+ mAllowSSH *systray.MenuItem
+ mAutoConnect *systray.MenuItem
+ mEnableRosenpass *systray.MenuItem
+ mNotifications *systray.MenuItem
+ mAdvancedSettings *systray.MenuItem
+ mCreateDebugBundle *systray.MenuItem
+ mExitNode *systray.MenuItem
// application with main windows.
app fyne.App
@@ -200,6 +202,14 @@ type serviceClient struct {
wRoutes fyne.Window
eventManager *event.Manager
+
+ exitNodeMu sync.Mutex
+ mExitNodeItems []menuHandler
+}
+
+type menuHandler struct {
+ *systray.MenuItem
+ cancel context.CancelFunc
}
// newServiceClient instance constructor
@@ -473,6 +483,9 @@ func (s *serviceClient) updateStatus() error {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
+ if s.connected {
+ s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
+ }
s.setDisconnectedStatus()
return err
}
@@ -498,7 +511,8 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Connected")
s.mUp.Disable()
s.mDown.Enable()
- s.mRoutes.Enable()
+ s.mNetworks.Enable()
+ go s.updateExitNodes()
systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
s.setDisconnectedStatus()
@@ -554,7 +568,9 @@ func (s *serviceClient) setDisconnectedStatus() {
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
- s.mRoutes.Disable()
+ s.mNetworks.Disable()
+ s.mExitNode.Disable()
+ go s.updateExitNodes()
}
func (s *serviceClient) onTrayReady() {
@@ -575,12 +591,18 @@ func (s *serviceClient) onTrayReady() {
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false)
- s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", "Enable notifications", true)
+ s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", "Enable notifications", false)
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application")
+ s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", "Create and open debug information bundle")
s.loadSettings()
- s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window")
- s.mRoutes.Disable()
+ s.exitNodeMu.Lock()
+ s.mExitNode = systray.AddMenuItem("Exit Node", "Select exit node for routing traffic")
+ s.mExitNode.Disable()
+ s.exitNodeMu.Unlock()
+
+ s.mNetworks = systray.AddMenuItem("Networks", "Open the networks management window")
+ s.mNetworks.Disable()
systray.AddSeparator()
s.mAbout = systray.AddMenuItem("About", "About")
@@ -599,6 +621,9 @@ func (s *serviceClient) onTrayReady() {
systray.AddSeparator()
s.mQuit = systray.AddMenuItem("Quit", "Quit the client app")
+ // update exit node menu in case service is already connected
+ go s.updateExitNodes()
+
s.update.SetOnUpdateListener(s.onUpdateAvailable)
go func() {
s.getSrvConfig()
@@ -614,6 +639,12 @@ func (s *serviceClient) onTrayReady() {
s.eventManager = event.NewManager(s.app, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
+ s.eventManager.AddHandler(func(event *proto.SystemEvent) {
+ if event.Category == proto.SystemEvent_SYSTEM {
+ s.updateExitNodes()
+ }
+ })
+
go s.eventManager.Start(s.ctx)
go func() {
@@ -628,7 +659,7 @@ func (s *serviceClient) onTrayReady() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
- s.runSelfCommand("error-msg", err.Error())
+ s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
@@ -638,7 +669,7 @@ func (s *serviceClient) onTrayReady() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
- s.runSelfCommand("error-msg", err.Error())
+ s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
@@ -676,6 +707,13 @@ func (s *serviceClient) onTrayReady() {
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
+ case <-s.mCreateDebugBundle.ClickedCh:
+ go func() {
+ if err := s.createAndOpenDebugBundle(); err != nil {
+ log.Errorf("Failed to create debug bundle: %v", err)
+ s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle"))
+ }
+ }()
case <-s.mQuit.ClickedCh:
systray.Quit()
return
@@ -684,10 +722,10 @@ func (s *serviceClient) onTrayReady() {
if err != nil {
log.Errorf("%s", err)
}
- case <-s.mRoutes.ClickedCh:
- s.mRoutes.Disable()
+ case <-s.mNetworks.ClickedCh:
+ s.mNetworks.Disable()
go func() {
- defer s.mRoutes.Enable()
+ defer s.mNetworks.Enable()
s.runSelfCommand("networks", "true")
}()
case <-s.mNotifications.ClickedCh:
@@ -718,7 +756,11 @@ func (s *serviceClient) runSelfCommand(command, arg string) {
return
}
- cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg))
+ cmd := exec.Command(proc,
+ fmt.Sprintf("--%s=%s", command, arg),
+ fmt.Sprintf("--daemon-addr=%s", s.addr),
+ )
+
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start %s UI: %v, %s", command, err, string(out))
@@ -737,7 +779,12 @@ func normalizedVersion(version string) string {
return versionString
}
-func (s *serviceClient) onTrayExit() {}
+// onTrayExit is called when the tray icon is closed.
+func (s *serviceClient) onTrayExit() {
+ for _, item := range s.mExitNodeItems {
+ item.cancel()
+ }
+}
// getSrvClient connection to the service.
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
@@ -753,7 +800,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
strings.TrimPrefix(s.addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
- grpc.WithUserAgent(system.GetDesktopUIUserAgent()),
+ grpc.WithUserAgent(desktop.GetUIUserAgent()),
)
if err != nil {
return nil, fmt.Errorf("dial service: %w", err)
diff --git a/client/ui/config/config.go b/client/ui/config/config.go
deleted file mode 100644
index fc3361b61..000000000
--- a/client/ui/config/config.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package config
-
-import (
- "os"
- "runtime"
-)
-
-// ClientConfig basic settings for the UI application.
-type ClientConfig struct {
- configPath string
- logFile string
- daemonAddr string
-}
-
-// Config object with default settings.
-//
-// We are creating this package to extract utility functions from the cmd package
-// reading and parsing the configurations for the client should be done here
-func Config() *ClientConfig {
- defaultConfigPath := "/etc/wiretrustee/config.json"
- defaultLogFile := "/var/log/wiretrustee/client.log"
- if runtime.GOOS == "windows" {
- defaultConfigPath = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" + "config.json"
- defaultLogFile = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" + "client.log"
- }
-
- defaultDaemonAddr := "unix:///var/run/wiretrustee.sock"
- if runtime.GOOS == "windows" {
- defaultDaemonAddr = "tcp://127.0.0.1:41731"
- }
- return &ClientConfig{
- configPath: defaultConfigPath,
- logFile: defaultLogFile,
- daemonAddr: defaultDaemonAddr,
- }
-}
-
-// DaemonAddr of the gRPC API.
-func (c *ClientConfig) DaemonAddr() string {
- return c.daemonAddr
-}
-
-// LogFile path.
-func (c *ClientConfig) LogFile() string {
- return c.logFile
-}
diff --git a/client/ui/debug.go b/client/ui/debug.go
new file mode 100644
index 000000000..845ea284c
--- /dev/null
+++ b/client/ui/debug.go
@@ -0,0 +1,50 @@
+//go:build !(linux && 386)
+
+package main
+
+import (
+ "fmt"
+ "path/filepath"
+
+ "fyne.io/fyne/v2"
+ "github.com/skratchdot/open-golang/open"
+
+ "github.com/netbirdio/netbird/client/proto"
+ nbstatus "github.com/netbirdio/netbird/client/status"
+)
+
+func (s *serviceClient) createAndOpenDebugBundle() error {
+ conn, err := s.getSrvClient(failFastTimeout)
+ if err != nil {
+ return fmt.Errorf("get client: %v", err)
+ }
+
+ statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
+ if err != nil {
+ return fmt.Errorf("failed to get status: %v", err)
+ }
+
+ overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil)
+ statusOutput := nbstatus.ParseToFullDetailSummary(overview)
+
+ resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{
+ Anonymize: true,
+ Status: statusOutput,
+ SystemInfo: true,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create debug bundle: %v", err)
+ }
+
+ bundleDir := filepath.Dir(resp.GetPath())
+ if err := open.Start(bundleDir); err != nil {
+ return fmt.Errorf("failed to open debug bundle directory: %v", err)
+ }
+
+ s.app.SendNotification(fyne.NewNotification(
+ "Debug Bundle",
+ fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()),
+ ))
+
+ return nil
+}
diff --git a/client/ui/desktop/desktop.go b/client/ui/desktop/desktop.go
new file mode 100644
index 000000000..0c99e2f38
--- /dev/null
+++ b/client/ui/desktop/desktop.go
@@ -0,0 +1,8 @@
+package desktop
+
+import "github.com/netbirdio/netbird/version"
+
+// GetUIUserAgent returns the Desktop ui user agent
+func GetUIUserAgent() string {
+ return "netbird-desktop-ui/" + version.NetbirdVersion()
+}
diff --git a/client/ui/event/event.go b/client/ui/event/event.go
index 7925ee4d3..4d949416d 100644
--- a/client/ui/event/event.go
+++ b/client/ui/event/event.go
@@ -3,6 +3,7 @@ package event
import (
"context"
"fmt"
+ "slices"
"strings"
"sync"
"time"
@@ -14,17 +15,20 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
- "github.com/netbirdio/netbird/client/system"
+ "github.com/netbirdio/netbird/client/ui/desktop"
)
+type Handler func(*proto.SystemEvent)
+
type Manager struct {
app fyne.App
addr string
- mu sync.Mutex
- ctx context.Context
- cancel context.CancelFunc
- enabled bool
+ mu sync.Mutex
+ ctx context.Context
+ cancel context.CancelFunc
+ enabled bool
+ handlers []Handler
}
func NewManager(app fyne.App, addr string) *Manager {
@@ -100,20 +104,41 @@ func (e *Manager) SetNotificationsEnabled(enabled bool) {
func (e *Manager) handleEvent(event *proto.SystemEvent) {
e.mu.Lock()
enabled := e.enabled
+ handlers := slices.Clone(e.handlers)
e.mu.Unlock()
- if !enabled {
+ // critical events are always shown
+ if !enabled && event.Severity != proto.SystemEvent_CRITICAL {
return
}
- title := e.getEventTitle(event)
- e.app.SendNotification(fyne.NewNotification(title, event.UserMessage))
+ if event.UserMessage != "" {
+ title := e.getEventTitle(event)
+ body := event.UserMessage
+ id := event.Metadata["id"]
+ if id != "" {
+ body += fmt.Sprintf(" ID: %s", id)
+ }
+ e.app.SendNotification(fyne.NewNotification(title, body))
+ }
+
+ for _, handler := range handlers {
+ go handler(event)
+ }
+}
+
+func (e *Manager) AddHandler(handler Handler) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.handlers = append(e.handlers, handler)
}
func (e *Manager) getEventTitle(event *proto.SystemEvent) string {
var prefix string
switch event.Severity {
- case proto.SystemEvent_ERROR, proto.SystemEvent_CRITICAL:
+ case proto.SystemEvent_CRITICAL:
+ prefix = "Critical"
+ case proto.SystemEvent_ERROR:
prefix = "Error"
case proto.SystemEvent_WARNING:
prefix = "Warning"
@@ -142,7 +167,7 @@ func getClient(addr string) (proto.DaemonServiceClient, error) {
conn, err := grpc.NewClient(
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
- grpc.WithUserAgent(system.GetDesktopUIUserAgent()),
+ grpc.WithUserAgent(desktop.GetUIUserAgent()),
)
if err != nil {
return nil, err
diff --git a/client/ui/network.go b/client/ui/network.go
index 852c4765b..750788cf3 100644
--- a/client/ui/network.go
+++ b/client/ui/network.go
@@ -3,7 +3,9 @@
package main
import (
+ "context"
"fmt"
+ "runtime"
"sort"
"strings"
"time"
@@ -13,6 +15,7 @@ import (
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
+ "fyne.io/systray"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
@@ -237,14 +240,14 @@ func (s *serviceClient) selectNetwork(id string, checked bool) {
s.showError(fmt.Errorf("failed to select network: %v", err))
return
}
- log.Infof("Route %s selected", id)
+ log.Infof("Network '%s' selected", id)
} else {
if _, err := conn.DeselectNetworks(s.ctx, req); err != nil {
log.Errorf("failed to deselect network: %v", err)
s.showError(fmt.Errorf("failed to deselect network: %v", err))
return
}
- log.Infof("Network %s deselected", id)
+ log.Infof("Network '%s' deselected", id)
}
}
@@ -324,6 +327,201 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs,
s.updateNetworks(grid, f)
}
+func (s *serviceClient) updateExitNodes() {
+ conn, err := s.getSrvClient(defaultFailTimeout)
+ if err != nil {
+ log.Errorf("get client: %v", err)
+ return
+ }
+
+ exitNodes, err := s.getExitNodes(conn)
+ if err != nil {
+ log.Errorf("get exit nodes: %v", err)
+ return
+ }
+
+ s.exitNodeMu.Lock()
+ defer s.exitNodeMu.Unlock()
+
+ s.recreateExitNodeMenu(exitNodes)
+
+ if len(s.mExitNodeItems) > 0 {
+ s.mExitNode.Enable()
+ } else {
+ s.mExitNode.Disable()
+ }
+
+ log.Debugf("Exit nodes updated: %d", len(s.mExitNodeItems))
+}
+
+func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
+ for _, node := range s.mExitNodeItems {
+ node.cancel()
+ node.Remove()
+ }
+ s.mExitNodeItems = nil
+
+ if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" {
+ s.mExitNode.Remove()
+ s.mExitNode = systray.AddMenuItem("Exit Node", "Select exit node for routing traffic")
+ }
+
+ for _, node := range exitNodes {
+ menuItem := s.mExitNode.AddSubMenuItemCheckbox(
+ node.ID,
+ fmt.Sprintf("Use exit node %s", node.ID),
+ node.Selected,
+ )
+
+ ctx, cancel := context.WithCancel(context.Background())
+ s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{
+ MenuItem: menuItem,
+ cancel: cancel,
+ })
+ go s.handleChecked(ctx, node.ID, menuItem)
+ }
+
+}
+
+func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) {
+ ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
+ defer cancel()
+
+ resp, err := conn.ListNetworks(ctx, &proto.ListNetworksRequest{})
+ if err != nil {
+ return nil, fmt.Errorf("list networks: %v", err)
+ }
+
+ var exitNodes []*proto.Network
+ for _, network := range resp.Routes {
+ if network.Range == "0.0.0.0/0" {
+ exitNodes = append(exitNodes, network)
+ }
+ }
+ return exitNodes, nil
+}
+
+func (s *serviceClient) handleChecked(ctx context.Context, id string, item *systray.MenuItem) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case _, ok := <-item.ClickedCh:
+ if !ok {
+ return
+ }
+ if err := s.toggleExitNode(id, item); err != nil {
+ log.Errorf("failed to toggle exit node: %v", err)
+ continue
+ }
+ }
+ }
+}
+
+// Add function to toggle exit node selection
+func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) error {
+ conn, err := s.getSrvClient(defaultFailTimeout)
+ if err != nil {
+ return fmt.Errorf("get client: %v", err)
+ }
+
+ log.Infof("Toggling exit node '%s'", nodeID)
+
+ s.exitNodeMu.Lock()
+ defer s.exitNodeMu.Unlock()
+
+ exitNodes, err := s.getExitNodes(conn)
+ if err != nil {
+ return fmt.Errorf("get exit nodes: %v", err)
+ }
+
+ var exitNode *proto.Network
+ // find other selected nodes and ours
+ ids := make([]string, 0, len(exitNodes))
+ for _, node := range exitNodes {
+ if node.ID == nodeID {
+ // preserve original state
+ cp := *node //nolint:govet
+ exitNode = &cp
+
+ // set desired state for recreation
+ node.Selected = true
+ continue
+ }
+ if node.Selected {
+ ids = append(ids, node.ID)
+
+ // set desired state for recreation
+ node.Selected = false
+ }
+ }
+
+ if item.Checked() && len(ids) == 0 {
+ // exit node is the only selected node, deselect it
+ ids = append(ids, nodeID)
+ exitNode = nil
+ }
+
+ // deselect all other selected exit nodes
+ if err := s.deselectOtherExitNodes(conn, ids, item); err != nil {
+ return err
+ }
+
+ if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil {
+ return err
+ }
+
+ // linux/bsd doesn't handle Check/Uncheck well, so we recreate the menu
+ if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" {
+ s.recreateExitNodeMenu(exitNodes)
+ }
+
+ return nil
+}
+
+func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string, currentItem *systray.MenuItem) error {
+ // deselect all other selected exit nodes
+ if len(ids) > 0 {
+ deselectReq := &proto.SelectNetworksRequest{
+ NetworkIDs: ids,
+ }
+ if _, err := conn.DeselectNetworks(s.ctx, deselectReq); err != nil {
+ return fmt.Errorf("deselect networks: %v", err)
+ }
+
+ log.Infof("Deselected exit nodes: %v", ids)
+ }
+
+ // uncheck all other exit node menu items
+ for _, i := range s.mExitNodeItems {
+ if i.MenuItem == currentItem {
+ continue
+ }
+ i.Uncheck()
+ log.Infof("Unchecked exit node %v", i)
+ }
+
+ return nil
+}
+
+func (s *serviceClient) selectNewExitNode(conn proto.DaemonServiceClient, exitNode *proto.Network, nodeID string, item *systray.MenuItem) error {
+ if exitNode != nil && !exitNode.Selected {
+ selectReq := &proto.SelectNetworksRequest{
+ NetworkIDs: []string{exitNode.ID},
+ Append: true,
+ }
+ if _, err := conn.SelectNetworks(s.ctx, selectReq); err != nil {
+ return fmt.Errorf("select network: %v", err)
+ }
+
+ log.Infof("Selected exit node '%s'", nodeID)
+ }
+
+ item.Check()
+
+ return nil
+}
+
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
switch tabs.Selected().Text {
case overlappingNetworksText:
diff --git a/go.sum b/go.sum
index c0685caa9..0ccf91b5d 100644
--- a/go.sum
+++ b/go.sum
@@ -18,10 +18,10 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW
cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg=
cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8=
cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0=
-cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs=
-cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w=
-cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
-cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
+cloud.google.com/go/auth v0.14.1 h1:AwoJbzUdxA/whv1qj3TLKwh3XX5sikny2fc40wUl+h0=
+cloud.google.com/go/auth v0.14.1/go.mod h1:4JHUxlGXisL0AW8kXPtUF6ztuOksyfUQNFjfsOCXkPM=
+cloud.google.com/go/auth/oauth2adapt v0.2.7 h1:/Lc7xODdqcEw8IrZ9SvwnlLX6j9FHQM74z6cBk9Rw6M=
+cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc=
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
@@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
-cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
-cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
+cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk=
@@ -225,8 +225,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
-github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
+github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
@@ -263,14 +263,12 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
-github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
+github.com/golang/glog v1.2.3 h1:oDTdz9f5VGVVNGu/Q7UXKWYsD0873HXLHdJUNBsSEKM=
+github.com/golang/glog v1.2.3/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
-github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
@@ -345,18 +343,18 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
-github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
-github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
+github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
+github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
-github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
+github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
+github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
-github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
-github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4=
+github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
+github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
@@ -617,8 +615,8 @@ github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KW
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
-github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
-github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
+github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
+github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
@@ -683,11 +681,11 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U=
github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI=
@@ -739,28 +737,28 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
-go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
-go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
-go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs=
-go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4=
+go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
+go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0/go.mod h1:HDBUsEjOuRC0EzKZ1bSaRGZWUBAzo+MhAcUUORSr4D0=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q=
+go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
+go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
-go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30=
-go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4=
-go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8=
-go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs=
-go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y=
-go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE=
-go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA=
-go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
+go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
+go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
+go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
+go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
+go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU=
+go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ=
+go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
+go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
@@ -885,8 +883,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
-golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
+golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
+golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -900,8 +898,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
-golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
-golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
+golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
+golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -1019,8 +1017,8 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
-golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
+golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -1114,8 +1112,8 @@ google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjR
google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU=
google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94=
google.golang.org/api v0.44.0/go.mod h1:EBOGZqzyhtvMDoxwS97ctnh0zUmYY6CxqXsc1AvkYD8=
-google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
-google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
+google.golang.org/api v0.220.0 h1:3oMI4gdBgB72WFVwE1nerDD8W3HUOS4kypK6rRLbGns=
+google.golang.org/api v0.220.0/go.mod h1:26ZAlY6aN/8WgpCzjPNy18QpYaz7Zgg1h0qe1GkZEmY=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -1164,10 +1162,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A=
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
-google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
-google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
+google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
+google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q=
+google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287 h1:J1H9f+LEdWAfHcez/4cvaVBox7cOYT+IU6rgqj5x++8=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287/go.mod h1:8BS3B93F/U1juMFq9+EDk+qOT5CO1R9IzXxG3PTqiRk=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1188,8 +1187,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
-google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
-google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
+google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ=
+google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
@@ -1204,8 +1203,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
-google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
+google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
+google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/management/client/client.go b/management/client/client.go
index e9eeaccc1..950f6137e 100644
--- a/management/client/client.go
+++ b/management/client/client.go
@@ -15,7 +15,7 @@ type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
- Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
+ Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
diff --git a/management/client/client_test.go b/management/client/client_test.go
index 2e9ce8b8b..73427b38a 100644
--- a/management/client/client_test.go
+++ b/management/client/client_test.go
@@ -79,7 +79,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
- mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
+ mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -206,7 +206,7 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Error(err)
}
info := system.GetInfo(context.TODO())
- resp, err := client.Register(*key, ValidKey, "", info, nil)
+ resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -236,7 +236,7 @@ func TestClient_Sync(t *testing.T) {
}
info := system.GetInfo(context.TODO())
- _, err = client.Register(*serverKey, ValidKey, "", info, nil)
+ _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -252,7 +252,7 @@ func TestClient_Sync(t *testing.T) {
}
info = system.GetInfo(context.TODO())
- _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
+ _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -353,7 +353,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
- _, err = testClient.Register(*key, ValidKey, "", info, nil)
+ _, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
diff --git a/management/client/grpc.go b/management/client/grpc.go
index d02509c27..d3aaffec0 100644
--- a/management/client/grpc.go
+++ b/management/client/grpc.go
@@ -365,12 +365,12 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
-func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
+func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
- return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys})
+ return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
diff --git a/management/client/mock.go b/management/client/mock.go
index 11564093a..9e1786f82 100644
--- a/management/client/mock.go
+++ b/management/client/mock.go
@@ -14,7 +14,7 @@ type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
- RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
+ RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
@@ -46,11 +46,11 @@ func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
return m.GetServerPublicKeyFunc()
}
-func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) {
+func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
- return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey)
+ return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
diff --git a/management/cmd/management.go b/management/cmd/management.go
index 469be4d67..64ca958b5 100644
--- a/management/cmd/management.go
+++ b/management/cmd/management.go
@@ -41,13 +41,12 @@ import (
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
+ "github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -264,11 +263,11 @@ var (
tlsEnabled = true
}
- jwtValidator, err := jwtclaims.NewJWTValidator(
- ctx,
+ authManager := auth.NewManager(store,
config.HttpConfig.AuthIssuer,
- config.GetAuthAudiences(),
+ config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
+<<<<<<< HEAD
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
@@ -282,12 +281,24 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
+=======
+ config.HttpConfig.AuthUserIDClaim,
+ config.GetAuthAudiences(),
+ config.HttpConfig.IdpSignKeyRefreshEnabled)
+ userManager := users.NewManager(store)
+ settingsManager := settings.NewManager(store)
+ permissionsManager := permissions.NewManager(userManager, settingsManager)
+>>>>>>> main
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
+<<<<<<< HEAD
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator, proxyController, permissionsManager, peersManager)
+=======
+ httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, config, integratedPeerValidator)
+>>>>>>> main
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
@@ -296,7 +307,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
- srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager)
+ srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
diff --git a/management/server/account.go b/management/server/account.go
index c62acf6df..8cb5bec07 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -2,11 +2,8 @@ package server
import (
"context"
- "crypto/sha256"
- b64 "encoding/base64"
"errors"
"fmt"
- "hash/crc32"
"math/rand"
"net"
"net/netip"
@@ -24,15 +21,19 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
- "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
+<<<<<<< HEAD
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/jwtclaims"
+=======
+ "github.com/netbirdio/netbird/management/server/integrated_validator"
+>>>>>>> main
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@@ -78,13 +79,10 @@ type AccountManager interface {
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
- GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
- CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
- GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
+ GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
- MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
- GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
+ GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
@@ -151,6 +149,7 @@ type AccountManager interface {
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
+ SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
}
type DefaultAccountManager struct {
@@ -256,6 +255,11 @@ func BuildManager(
metrics telemetry.AppMetrics,
proxyController port_forwarding.Controller,
) (*DefaultAccountManager, error) {
+ start := time.Now()
+ defer func() {
+ log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
+ }()
+
am := &DefaultAccountManager{
Store: store,
geo: geo,
@@ -274,39 +278,21 @@ func BuildManager(
requestBuffer: NewAccountRequestBuffer(ctx, store),
proxyController: proxyController,
}
- allAccounts := store.GetAllAccounts(ctx)
+ accountsCounter, err := store.GetAccountsCounter(ctx)
+ if err != nil {
+ log.WithContext(ctx).Error(err)
+ }
+
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
- am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
+ am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
- log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts))
+ log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", accountsCounter)
} else {
- log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts))
- }
-
- // if account doesn't have a default group
- // we create 'all' group and add all peers into it
- // also we create default rule with source as destination
- for _, account := range allAccounts {
- shouldSave := false
-
- _, err := account.GetGroupAll()
- if err != nil {
- if err := addAllGroup(account); err != nil {
- return nil, err
- }
- shouldSave = true
- }
-
- if shouldSave {
- err = store.SaveAccount(ctx, account)
- if err != nil {
- return nil, err
- }
- }
+ log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
@@ -959,11 +945,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
}
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
-func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
+func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
primaryDomain bool,
) error {
- if claims.Domain == "" {
- log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
+ if userAuth.Domain == "" {
+ log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth)
return nil
}
@@ -976,11 +962,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return err
}
- if domainIsUpToDate(accountDomain, domainCategory, claims) {
+ if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
return nil
}
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
@@ -989,13 +975,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
newDomain := accountDomain
newCategoty := domainCategory
- lowerDomain := strings.ToLower(claims.Domain)
+ lowerDomain := strings.ToLower(userAuth.Domain)
if accountDomain != lowerDomain && user.HasAdminPower() {
newDomain = lowerDomain
}
if accountDomain == lowerDomain {
- newCategoty = claims.DomainCategory
+ newCategoty = userAuth.DomainCategory
}
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
@@ -1011,16 +997,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
userAccountID string,
domainAccountID string,
- claims jwtclaims.AuthorizationClaims,
+ userAuth nbcontext.UserAuth,
) error {
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
- err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
+ err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
- err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
+ err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
if err != nil {
return err
}
@@ -1030,20 +1016,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
-func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
- if claims.UserId == "" {
+func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
+ if userAuth.UserId == "" {
return "", fmt.Errorf("user ID is empty")
}
- lowerDomain := strings.ToLower(claims.Domain)
+ lowerDomain := strings.ToLower(userAuth.Domain)
- newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
+ newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
if err != nil {
return "", err
}
newAccount.Domain = lowerDomain
- newAccount.DomainCategory = claims.DomainCategory
+ newAccount.DomainCategory = userAuth.DomainCategory
newAccount.IsDomainPrimaryAccount = true
err = am.Store.SaveAccount(ctx, newAccount)
@@ -1051,33 +1037,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
return "", err
}
- err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
+ err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
if err != nil {
return "", err
}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
+ am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
return newAccount.Id, nil
}
-func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
+func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
- newUser := types.NewRegularUser(claims.UserId)
+ newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil {
return "", err
}
- err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
+ err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
if err != nil {
return "", err
}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
+ am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
}
@@ -1117,76 +1103,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}
-// MarkPATUsed marks a personal access token as used
-func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
- return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
-}
-
// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
return am.Store.GetAccount(ctx, accountID)
}
-// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
-func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
- user, pat, err = am.extractPATFromToken(ctx, token)
- if err != nil {
- return nil, nil, "", "", err
- }
-
- domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
- if err != nil {
- return nil, nil, "", "", err
- }
-
- return user, pat, domain, category, nil
-}
-
-// extractPATFromToken validates the token structure and retrieves associated User and PAT.
-func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
- if len(token) != types.PATLength {
- return nil, nil, fmt.Errorf("token has incorrect length")
- }
-
- prefix := token[:len(types.PATPrefix)]
- if prefix != types.PATPrefix {
- return nil, nil, fmt.Errorf("token has wrong prefix")
- }
- secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
- encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
-
- verificationChecksum, err := base62.Decode(encodedChecksum)
- if err != nil {
- return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
- }
-
- secretChecksum := crc32.ChecksumIEEE([]byte(secret))
- if secretChecksum != verificationChecksum {
- return nil, nil, fmt.Errorf("token checksum does not match")
- }
-
- hashedToken := sha256.Sum256([]byte(token))
- encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
-
- var user *types.User
- var pat *types.PersonalAccessToken
-
- err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
- if err != nil {
- return err
- }
-
- user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
- return err
- })
- if err != nil {
- return nil, nil, err
- }
-
- return user, pat, nil
-}
-
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
@@ -1201,58 +1122,56 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
return am.Store.GetAccount(ctx, accountID)
}
-// GetAccountIDFromToken returns an account ID associated with this token.
-func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- if claims.UserId == "" {
+func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID)
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
- claims.Domain = am.singleAccountModeDomain
- claims.DomainCategory = types.PrivateCategory
+ userAuth.Domain = am.singleAccountModeDomain
+ userAuth.DomainCategory = types.PrivateCategory
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
- accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
+ accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
if err != nil {
return "", "", err
}
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
- return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
+ return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
+ }
+
+ if userAuth.IsChild {
+ return accountID, user.Id, nil
}
if user.AccountID != accountID {
- return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
+ return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
}
- if !user.IsServiceUser && claims.Invited {
+ if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return "", "", err
}
}
- if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
- return "", "", err
- }
-
return accountID, user.Id, nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
-func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
- if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
- if isToken, ok := claim.(bool); ok && isToken {
- return nil
- }
+// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
+func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ if userAuth.IsChild || userAuth.IsPAT {
+ return nil
}
- settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
+ settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return err
}
@@ -1266,9 +1185,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return nil
}
- jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
-
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
defer func() {
if unlockAccount != nil {
unlockAccount()
@@ -1280,17 +1197,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
var hasChanges bool
var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
- user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
+ user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
- groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
+ groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
- changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
+ changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
@@ -1315,7 +1232,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
- groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
+ groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -1325,7 +1242,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
groupsMap[group.ID] = group
}
- peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId)
+ peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
@@ -1339,7 +1256,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
- if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
+ if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@@ -1357,45 +1274,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
- group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
+ group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
if err != nil {
- log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
- am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
+ am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
- group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
+ group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
if err != nil {
- log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
- am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
+ am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta)
}
}
if settings.GroupsPropagationEnabled {
- removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
+ removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
- newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
+ newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
- log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
- am.UpdateAccountPeers(ctx, accountID)
+ log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
+ am.UpdateAccountPeers(ctx, userAuth.AccountId)
}
}
@@ -1420,24 +1337,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
-func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
+//
+// UserAuth IsChild -> checks that account exists
+func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
- claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
+ userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
- if claims.UserId == "" {
+ if userAuth.UserId == "" {
return "", errors.New(emptyUserID)
}
- if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) {
- return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
+ if userAuth.IsChild {
+ exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
+ if err != nil || !exists {
+ return "", err
+ }
+ return userAuth.AccountId, nil
}
- if claims.AccountId != "" {
- return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
+ if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
+ return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
+ }
+
+ if userAuth.AccountId != "" {
+ return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth)
}
// We checked if the domain has a primary account already
- domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
+ domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain)
if cancel != nil {
defer cancel()
}
@@ -1445,14 +1372,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
- userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
+ userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != "" {
- if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
+ if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil {
return "", err
}
@@ -1460,10 +1387,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
}
if domainAccountID != "" {
- return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
+ return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth)
}
- return am.addNewPrivateAccount(ctx, domainAccountID, claims)
+ return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
@@ -1491,40 +1418,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
return domainAccountID, cancel, nil
}
-func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
- userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
+func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
+ userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
- if userAccountID != claims.AccountId {
- return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
+ if userAccountID != userAuth.AccountId {
+ return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
}
- accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId)
+ accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
}
- if domainIsUpToDate(accountDomain, domainCategory, claims) {
- return claims.AccountId, nil
+ if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
+ return userAuth.AccountId, nil
}
// We checked if the domain has a primary account already
- domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain)
+ domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
}
- err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
+ err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth)
if err != nil {
return "", err
}
- return claims.AccountId, nil
+ return userAuth.AccountId, nil
}
func handleNotFound(err error) error {
@@ -1539,8 +1466,8 @@ func handleNotFound(err error) error {
return nil
}
-func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
- return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain
+func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
+ return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -1622,34 +1549,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}
-// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
-// group propagation and set the list of groups with access permissions.
-func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
- accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
- if err != nil {
- return err
- }
-
- settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
- if err != nil {
- return err
- }
-
- // Ensures JWT group synchronization to the management is enabled before,
- // filtering access based on the allowed groups.
- if settings != nil && settings.JWTGroupsEnabled {
- if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
- userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
-
- if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
- return fmt.Errorf("user does not belong to any of the allowed JWT groups")
- }
- }
- }
-
- return nil
-}
-
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
am.UpdateAccountPeers(ctx, accountID)
@@ -1717,46 +1616,6 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
}
-// addAllGroup to account object if it doesn't exist
-func addAllGroup(account *types.Account) error {
- if len(account.Groups) == 0 {
- allGroup := &types.Group{
- ID: xid.New().String(),
- Name: "All",
- Issued: types.GroupIssuedAPI,
- }
- for _, peer := range account.Peers {
- allGroup.Peers = append(allGroup.Peers, peer.ID)
- }
- account.Groups = map[string]*types.Group{allGroup.ID: allGroup}
-
- id := xid.New().String()
-
- defaultPolicy := &types.Policy{
- ID: id,
- Name: types.DefaultRuleName,
- Description: types.DefaultRuleDescription,
- Enabled: true,
- Rules: []*types.PolicyRule{
- {
- ID: id,
- Name: types.DefaultRuleName,
- Description: types.DefaultRuleDescription,
- Enabled: true,
- Sources: []string{allGroup.ID},
- Destinations: []string{allGroup.ID},
- Bidirectional: true,
- Protocol: types.PolicyRuleProtocolALL,
- Action: types.PolicyTrafficActionAccept,
- },
- },
- }
-
- account.Policies = []*types.Policy{defaultPolicy}
- }
- return nil
-}
-
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
@@ -1801,45 +1660,12 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
},
}
- if err := addAllGroup(acc); err != nil {
+ if err := acc.AddAllGroup(); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
}
-// extractJWTGroups extracts the group names from a JWT token's claims.
-func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
- userJWTGroups := make([]string, 0)
-
- if claim, ok := claims.Raw[claimName]; ok {
- if claimGroups, ok := claim.([]interface{}); ok {
- for _, g := range claimGroups {
- if group, ok := g.(string); ok {
- userJWTGroups = append(userJWTGroups, group)
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
- }
- }
- }
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
- }
-
- return userJWTGroups
-}
-
-// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
-func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
- for _, userGroup := range userGroups {
- for _, allowedGroup := range allowedGroups {
- if userGroup == allowedGroup {
- return true
- }
- }
- }
- return false
-}
-
// separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs.
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 802cc8bd2..8a042f4c3 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -2,8 +2,6 @@ package server
import (
"context"
- "crypto/sha256"
- b64 "encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -15,9 +13,12 @@ import (
"testing"
"time"
+<<<<<<< HEAD
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
+=======
+>>>>>>> main
"github.com/netbirdio/netbird/management/server/util"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,7 +32,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -438,7 +439,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
- type initUserParams jwtclaims.AuthorizationClaims
+ type initUserParams nbcontext.UserAuth
var (
publicDomain = "public.com"
@@ -461,7 +462,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCases := []struct {
name string
- inputClaims jwtclaims.AuthorizationClaims
+ inputClaims nbcontext.UserAuth
inputInitUserParams initUserParams
inputUpdateAttrs bool
inputUpdateClaimAccount bool
@@ -476,7 +477,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}{
{
name: "New User With Public Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: types.PublicCategory,
@@ -493,7 +494,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Unknown Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: types.UnknownCategory,
@@ -510,7 +511,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -527,7 +528,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New Regular User With Existing Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -545,7 +546,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing User With Existing Reclassified Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -562,7 +563,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing Account Id With Existing Reclassified Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -580,7 +581,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "User With Private Category And Empty Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
+ inputClaims: nbcontext.UserAuth{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -609,7 +610,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
- err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
+ err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -617,7 +618,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
- accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
+ accountID, _, err = manager.GetAccountIDFromUserAuth(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -636,14 +637,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}
}
-func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
+func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
-
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
-
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
@@ -651,65 +650,50 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
-
- claims := jwtclaims.AuthorizationClaims{
+ claims := nbcontext.UserAuth{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
- Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
+ Groups: []string{"group1", "group2"},
}
-
t.Run("JWT groups disabled", func(t *testing.T) {
- accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
- require.NoError(t, err, "get account by token failed")
-
+ err := manager.SyncUserJWTGroups(context.Background(), claims)
+ require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
-
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
-
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
-
- accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
- require.NoError(t, err, "get account by token failed")
-
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
+ require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
-
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
-
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
-
- accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
- require.NoError(t, err, "get account by token failed")
-
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
+ require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
-
require.Len(t, account.Groups, 3, "groups should be added to the account")
-
groupsByNames := map[string]*types.Group{}
for _, g := range account.Groups {
groupsByNames[g.Name] = g
}
-
g1, ok := groupsByNames["group1"]
require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match")
-
g2, ok := groupsByNames["group2"]
require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match")
@@ -717,88 +701,6 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
})
}
-func TestAccountManager_GetAccountFromPAT(t *testing.T) {
- store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
- if err != nil {
- t.Fatalf("Error when creating store: %s", err)
- }
- t.Cleanup(cleanup)
- account := newAccountWithId(context.Background(), "account_id", "testuser", "")
-
- token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
- hashedToken := sha256.Sum256([]byte(token))
- encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
- account.Users["someUser"] = &types.User{
- Id: "someUser",
- PATs: map[string]*types.PersonalAccessToken{
- "tokenId": {
- ID: "tokenId",
- UserID: "someUser",
- HashedToken: encodedHashedToken,
- },
- },
- }
- err = store.SaveAccount(context.Background(), account)
- if err != nil {
- t.Fatalf("Error when saving account: %s", err)
- }
-
- am := DefaultAccountManager{
- Store: store,
- }
-
- user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
- if err != nil {
- t.Fatalf("Error when getting Account from PAT: %s", err)
- }
-
- assert.Equal(t, "account_id", user.AccountID)
- assert.Equal(t, "someUser", user.Id)
- assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
-}
-
-func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
- store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
- if err != nil {
- t.Fatalf("Error when creating store: %s", err)
- }
- t.Cleanup(cleanup)
-
- account := newAccountWithId(context.Background(), "account_id", "testuser", "")
-
- token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
- hashedToken := sha256.Sum256([]byte(token))
- encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
- account.Users["someUser"] = &types.User{
- Id: "someUser",
- PATs: map[string]*types.PersonalAccessToken{
- "tokenId": {
- ID: "tokenId",
- HashedToken: encodedHashedToken,
- },
- },
- }
- err = store.SaveAccount(context.Background(), account)
- if err != nil {
- t.Fatalf("Error when saving account: %s", err)
- }
-
- am := DefaultAccountManager{
- Store: store,
- }
-
- err = am.MarkPATUsed(context.Background(), "tokenId")
- if err != nil {
- t.Fatalf("Error when marking PAT used: %s", err)
- }
-
- account, err = am.Store.GetAccount(context.Background(), "account_id")
- if err != nil {
- t.Fatalf("Error when getting account: %s", err)
- }
- assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
-}
-
func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -963,13 +865,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
- claims := jwtclaims.AuthorizationClaims{
+ claims := nbcontext.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
}
- publicClaims := jwtclaims.AuthorizationClaims{
+ publicClaims := nbcontext.UserAuth{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: types.PublicCategory,
@@ -2684,11 +2586,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{"group3"},
+ IsPAT: true,
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2697,11 +2601,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("empty jwt groups", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{},
}
- err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2710,11 +2615,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("jwt match existing api group", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{"group1"},
}
- err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2730,11 +2636,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{"group1"},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2747,11 +2654,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add jwt group", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{"group1", "group2"},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2760,11 +2668,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("existed group not update", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{"group2"},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2773,11 +2682,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add new group", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user2",
- Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
+ claims := nbcontext.UserAuth{
+ UserId: "user2",
+ AccountId: "accountID",
+ Groups: []string{"group1", "group3"},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
@@ -2790,11 +2700,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user1",
- Raw: jwt.MapClaims{"groups": []interface{}{}},
+ claims := nbcontext.UserAuth{
+ UserId: "user1",
+ AccountId: "accountID",
+ Groups: []string{},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@@ -2804,11 +2715,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
- claims := jwtclaims.AuthorizationClaims{
- UserId: "user2",
- Raw: jwt.MapClaims{},
+ claims := nbcontext.UserAuth{
+ UserId: "user2",
+ AccountId: "accountID",
+ Groups: []string{},
}
- err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
diff --git a/management/server/jwtclaims/extractor.go b/management/server/auth/jwt/extractor.go
similarity index 51%
rename from management/server/jwtclaims/extractor.go
rename to management/server/auth/jwt/extractor.go
index 18214b434..fab429125 100644
--- a/management/server/jwtclaims/extractor.go
+++ b/management/server/auth/jwt/extractor.go
@@ -1,15 +1,17 @@
-package jwtclaims
+package jwt
import (
- "net/http"
+ "errors"
+ "net/url"
"time"
"github.com/golang-jwt/jwt"
+ log "github.com/sirupsen/logrus"
+
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
)
const (
- // TokenUserProperty key for the user property in the request context
- TokenUserProperty = "user"
// AccountIDSuffix suffix for the account id claim
AccountIDSuffix = "wt_account_id"
// DomainIDSuffix suffix for the domain id claim
@@ -22,19 +24,16 @@ const (
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
- // IsToken claim indicates that auth type from the user is a token
- IsToken = "is_token"
)
-// ExtractClaims Extract function type
-type ExtractClaims func(r *http.Request) AuthorizationClaims
+var (
+ errUserIDClaimEmpty = errors.New("user ID claim token value is empty")
+)
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
authAudience string
userIDClaim string
-
- FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
@@ -54,13 +53,6 @@ func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
}
}
-// WithFromRequestContext sets the function that extracts claims from the request context
-func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
- return func(c *ClaimsExtractor) {
- c.FromRequestContext = ec
- }
-}
-
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
@@ -68,49 +60,13 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
for _, option := range options {
option(ce)
}
- if ce.FromRequestContext == nil {
- ce.FromRequestContext = ce.fromRequestContext
- }
+
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
-// FromToken extracts claims from the token (after auth)
-func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
- claims := token.Claims.(jwt.MapClaims)
- jwtClaims := AuthorizationClaims{
- Raw: claims,
- }
- userID, ok := claims[c.userIDClaim].(string)
- if !ok {
- return jwtClaims
- }
- jwtClaims.UserId = userID
- accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
- if ok {
- jwtClaims.AccountId = accountIDClaim.(string)
- }
- domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
- if ok {
- jwtClaims.Domain = domainClaim.(string)
- }
- domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
- if ok {
- jwtClaims.DomainCategory = domainCategoryClaim.(string)
- }
- LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix]
- if ok {
- jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
- }
- invitedBool, ok := claims[c.authAudience+Invited]
- if ok {
- jwtClaims.Invited = invitedBool.(bool)
- }
- return jwtClaims
-}
-
func parseTime(timeString string) time.Time {
if timeString == "" {
return time.Time{}
@@ -122,11 +78,67 @@ func parseTime(timeString string) time.Time {
return parsedTime
}
-// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
-func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
- if r.Context().Value(TokenUserProperty) == nil {
- return AuthorizationClaims{}
+func (c ClaimsExtractor) audienceClaim(claimName string) string {
+ url, err := url.JoinPath(c.authAudience, claimName)
+ if err != nil {
+ return c.authAudience + claimName // as it was previously
}
- token := r.Context().Value(TokenUserProperty).(*jwt.Token)
- return c.FromToken(token)
+
+ return url
+}
+
+func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
+ claims := token.Claims.(jwt.MapClaims)
+ userAuth := nbcontext.UserAuth{}
+
+ userID, ok := claims[c.userIDClaim].(string)
+ if !ok {
+ return userAuth, errUserIDClaimEmpty
+ }
+ userAuth.UserId = userID
+
+ if accountIDClaim, ok := claims[c.audienceClaim(AccountIDSuffix)]; ok {
+ userAuth.AccountId = accountIDClaim.(string)
+ }
+
+ if domainClaim, ok := claims[c.audienceClaim(DomainIDSuffix)]; ok {
+ userAuth.Domain = domainClaim.(string)
+ }
+
+ if domainCategoryClaim, ok := claims[c.audienceClaim(DomainCategorySuffix)]; ok {
+ userAuth.DomainCategory = domainCategoryClaim.(string)
+ }
+
+ if lastLoginClaimString, ok := claims[c.audienceClaim(LastLoginSuffix)]; ok {
+ userAuth.LastLogin = parseTime(lastLoginClaimString.(string))
+ }
+
+ if invitedBool, ok := claims[c.audienceClaim(Invited)]; ok {
+ if value, ok := invitedBool.(bool); ok {
+ userAuth.Invited = value
+ }
+ }
+
+ return userAuth, nil
+}
+
+func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
+ claims := token.Claims.(jwt.MapClaims)
+ userJWTGroups := make([]string, 0)
+
+ if claim, ok := claims[claimName]; ok {
+ if claimGroups, ok := claim.([]interface{}); ok {
+ for _, g := range claimGroups {
+ if group, ok := g.(string); ok {
+ userJWTGroups = append(userJWTGroups, group)
+ } else {
+ log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
+ }
+ }
+ }
+ } else {
+ log.Debugf("JWT claim %q is not a string array", claimName)
+ }
+
+ return userJWTGroups
}
diff --git a/management/server/auth/jwt/validator.go b/management/server/auth/jwt/validator.go
new file mode 100644
index 000000000..5b38ca786
--- /dev/null
+++ b/management/server/auth/jwt/validator.go
@@ -0,0 +1,302 @@
+package jwt
+
+import (
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "math/big"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/golang-jwt/jwt"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
+type Jwks struct {
+ Keys []JSONWebKey `json:"keys"`
+ expiresInTime time.Time
+}
+
+// The supported elliptic curves types
+const (
+ // p256 represents a cryptographic elliptical curve type.
+ p256 = "P-256"
+
+ // p384 represents a cryptographic elliptical curve type.
+ p384 = "P-384"
+
+ // p521 represents a cryptographic elliptical curve type.
+ p521 = "P-521"
+)
+
+// JSONWebKey is a representation of a Jason Web Key
+type JSONWebKey struct {
+ Kty string `json:"kty"`
+ Kid string `json:"kid"`
+ Use string `json:"use"`
+ N string `json:"n"`
+ E string `json:"e"`
+ Crv string `json:"crv"`
+ X string `json:"x"`
+ Y string `json:"y"`
+ X5c []string `json:"x5c"`
+}
+
+type Validator struct {
+ lock sync.Mutex
+ issuer string
+ audienceList []string
+ keysLocation string
+ idpSignkeyRefreshEnabled bool
+ keys *Jwks
+}
+
+var (
+ errKeyNotFound = errors.New("unable to find appropriate key")
+ errInvalidAudience = errors.New("invalid audience")
+ errInvalidIssuer = errors.New("invalid issuer")
+ errTokenEmpty = errors.New("required authorization token not found")
+ errTokenInvalid = errors.New("token is invalid")
+ errTokenParsing = errors.New("token could not be parsed")
+)
+
+func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
+ keys, err := getPemKeys(keysLocation)
+ if err != nil {
+ log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
+ }
+
+ return &Validator{
+ keys: keys,
+ issuer: issuer,
+ audienceList: audienceList,
+ keysLocation: keysLocation,
+ idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled,
+ }
+}
+
+func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
+ return func(token *jwt.Token) (interface{}, error) {
+ // Verify 'aud' claim
+ var checkAud bool
+ for _, audience := range v.audienceList {
+ checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
+ if checkAud {
+ break
+ }
+ }
+ if !checkAud {
+ return token, errInvalidAudience
+ }
+
+ // Verify 'issuer' claim
+ checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false)
+ if !checkIss {
+ return token, errInvalidIssuer
+ }
+
+ // If keys are rotated, verify the keys prior to token validation
+ if v.idpSignkeyRefreshEnabled {
+ // If the keys are invalid, retrieve new ones
+ // @todo propose a separate go routine to regularly check these to prevent blocking when actually
+ // validating the token
+ if !v.keys.stillValid() {
+ v.lock.Lock()
+ defer v.lock.Unlock()
+
+ refreshedKeys, err := getPemKeys(v.keysLocation)
+ if err != nil {
+ log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
+ refreshedKeys = v.keys
+ }
+
+ log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
+
+ v.keys = refreshedKeys
+ }
+ }
+
+ publicKey, err := getPublicKey(token, v.keys)
+ if err == nil {
+ return publicKey, nil
+ }
+
+ msg := fmt.Sprintf("getPublicKey error: %s", err)
+ if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled {
+ msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
+ }
+
+ log.WithContext(ctx).Error(msg)
+
+ return nil, err
+ }
+}
+
+// ValidateAndParse validates the token and returns the parsed token
+func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
+ // If the token is empty...
+ if token == "" {
+ // If we get here, the required token is missing
+ log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
+ return nil, errTokenEmpty
+ }
+
+ // Now parse the token
+ parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx))
+
+ // Check if there was an error in parsing...
+ if err != nil {
+ err = fmt.Errorf("%w: %s", errTokenParsing, err)
+ log.WithContext(ctx).Error(err.Error())
+ return nil, err
+ }
+
+ // Check if the parsed token is valid...
+ if !parsedToken.Valid {
+ log.WithContext(ctx).Debug(errTokenInvalid.Error())
+ return nil, errTokenInvalid
+ }
+
+ return parsedToken, nil
+}
+
+// stillValid returns true if the JSONWebKey still valid and have enough time to be used
+func (jwks *Jwks) stillValid() bool {
+ return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
+}
+
+func getPemKeys(keysLocation string) (*Jwks, error) {
+ jwks := &Jwks{}
+
+ url, err := url.ParseRequestURI(keysLocation)
+ if err != nil {
+ return jwks, err
+ }
+
+ resp, err := http.Get(url.String())
+ if err != nil {
+ return jwks, err
+ }
+ defer resp.Body.Close()
+
+ err = json.NewDecoder(resp.Body).Decode(jwks)
+ if err != nil {
+ return jwks, err
+ }
+
+ cacheControlHeader := resp.Header.Get("Cache-Control")
+ expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
+ jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ return jwks, nil
+}
+
+func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) {
+ // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
+ for k := range jwks.Keys {
+ if token.Header["kid"] != jwks.Keys[k].Kid {
+ continue
+ }
+
+ if len(jwks.Keys[k].X5c) != 0 {
+ cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
+ return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
+ }
+
+ if jwks.Keys[k].Kty == "RSA" {
+ return getPublicKeyFromRSA(jwks.Keys[k])
+ }
+ if jwks.Keys[k].Kty == "EC" {
+ return getPublicKeyFromECDSA(jwks.Keys[k])
+ }
+ }
+
+ return nil, errKeyNotFound
+}
+
+func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
+ if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
+ return nil, fmt.Errorf("ecdsa key incomplete")
+ }
+
+ var xCoordinate []byte
+ if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
+ return nil, err
+ }
+
+ var yCoordinate []byte
+ if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
+ return nil, err
+ }
+
+ publicKey = &ecdsa.PublicKey{}
+
+ var curve elliptic.Curve
+ switch jwk.Crv {
+ case p256:
+ curve = elliptic.P256()
+ case p384:
+ curve = elliptic.P384()
+ case p521:
+ curve = elliptic.P521()
+ }
+
+ publicKey.Curve = curve
+ publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
+ publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
+
+ return publicKey, nil
+}
+
+func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
+ decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
+ if err != nil {
+ return nil, err
+ }
+ decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
+ if err != nil {
+ return nil, err
+ }
+
+ var n, e big.Int
+ e.SetBytes(decodedE)
+ n.SetBytes(decodedN)
+
+ return &rsa.PublicKey{
+ E: int(e.Int64()),
+ N: &n,
+ }, nil
+}
+
+// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
+func getMaxAgeFromCacheHeader(cacheControl string) int {
+ // Split into individual directives
+ directives := strings.Split(cacheControl, ",")
+
+ for _, directive := range directives {
+ directive = strings.TrimSpace(directive)
+ if strings.HasPrefix(directive, "max-age=") {
+ // Extract the max-age value
+ maxAgeStr := strings.TrimPrefix(directive, "max-age=")
+ maxAge, err := strconv.Atoi(maxAgeStr)
+ if err != nil {
+ return 0
+ }
+
+ return maxAge
+ }
+ }
+
+ return 0
+}
diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go
new file mode 100644
index 000000000..6835a3ced
--- /dev/null
+++ b/management/server/auth/manager.go
@@ -0,0 +1,170 @@
+package auth
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "hash/crc32"
+
+ "github.com/golang-jwt/jwt"
+
+ "github.com/netbirdio/netbird/base62"
+ nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+var _ Manager = (*manager)(nil)
+
+type Manager interface {
+ ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
+ EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
+ MarkPATUsed(ctx context.Context, tokenID string) error
+ GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
+}
+
+type manager struct {
+ store store.Store
+
+ validator *nbjwt.Validator
+ extractor *nbjwt.ClaimsExtractor
+}
+
+func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
+ // @note if invalid/missing parameters are sent the validator will instantiate
+ // but it will fail when validating and parsing the token
+ jwtValidator := nbjwt.NewValidator(
+ issuer,
+ allAudiences,
+ keysLocation,
+ idpRefreshKeys,
+ )
+
+ claimsExtractor := nbjwt.NewClaimsExtractor(
+ nbjwt.WithAudience(audience),
+ nbjwt.WithUserIDClaim(userIdClaim),
+ )
+
+ return &manager{
+ store: store,
+
+ validator: jwtValidator,
+ extractor: claimsExtractor,
+ }
+}
+
+func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
+ token, err := m.validator.ValidateAndParse(ctx, value)
+ if err != nil {
+ return nbcontext.UserAuth{}, nil, err
+ }
+
+ userAuth, err := m.extractor.ToUserAuth(token)
+ if err != nil {
+ return nbcontext.UserAuth{}, nil, err
+ }
+ return userAuth, token, err
+}
+
+func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+ if userAuth.IsChild || userAuth.IsPAT {
+ return userAuth, nil
+ }
+
+ settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
+ if err != nil {
+ return userAuth, err
+ }
+
+ // Ensures JWT group synchronization to the management is enabled before,
+ // filtering access based on the allowed groups.
+ if settings != nil && settings.JWTGroupsEnabled {
+ userAuth.Groups = m.extractor.ToGroups(token, settings.JWTGroupsClaimName)
+ if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
+ if !userHasAllowedGroup(allowedGroups, userAuth.Groups) {
+ return userAuth, fmt.Errorf("user does not belong to any of the allowed JWT groups")
+ }
+ }
+ }
+
+ return userAuth, nil
+}
+
+// MarkPATUsed marks a personal access token as used
+func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
+ return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
+}
+
+// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
+func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
+ user, pat, err = am.extractPATFromToken(ctx, token)
+ if err != nil {
+ return nil, nil, "", "", err
+ }
+
+ domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
+ if err != nil {
+ return nil, nil, "", "", err
+ }
+
+ return user, pat, domain, category, nil
+}
+
+// extractPATFromToken validates the token structure and retrieves associated User and PAT.
+func (am *manager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
+ if len(token) != types.PATLength {
+ return nil, nil, fmt.Errorf("PAT has incorrect length")
+ }
+
+ prefix := token[:len(types.PATPrefix)]
+ if prefix != types.PATPrefix {
+ return nil, nil, fmt.Errorf("PAT has wrong prefix")
+ }
+ secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
+ encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
+
+ verificationChecksum, err := base62.Decode(encodedChecksum)
+ if err != nil {
+ return nil, nil, fmt.Errorf("PAT checksum decoding failed: %w", err)
+ }
+
+ secretChecksum := crc32.ChecksumIEEE([]byte(secret))
+ if secretChecksum != verificationChecksum {
+ return nil, nil, fmt.Errorf("PAT checksum does not match")
+ }
+
+ hashedToken := sha256.Sum256([]byte(token))
+ encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
+
+ var user *types.User
+ var pat *types.PersonalAccessToken
+
+ err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
+ pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
+ if err != nil {
+ return err
+ }
+
+ user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
+ return err
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return user, pat, nil
+}
+
+// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
+func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
+ for _, userGroup := range userGroups {
+ for _, allowedGroup := range allowedGroups {
+ if userGroup == allowedGroup {
+ return true
+ }
+ }
+ }
+ return false
+}
diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go
new file mode 100644
index 000000000..bc7066548
--- /dev/null
+++ b/management/server/auth/manager_mock.go
@@ -0,0 +1,54 @@
+package auth
+
+import (
+ "context"
+
+ "github.com/golang-jwt/jwt"
+
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+var (
+ _ Manager = (*MockManager)(nil)
+)
+
+// @note really dislike this mocking approach but rather than have to do additional test refactoring.
+type MockManager struct {
+ ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
+ EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
+ MarkPATUsedFunc func(ctx context.Context, tokenID string) error
+ GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
+}
+
+// EnsureUserAccessByJWTGroups implements Manager.
+func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+ if m.EnsureUserAccessByJWTGroupsFunc != nil {
+ return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
+ }
+ return nbcontext.UserAuth{}, nil
+}
+
+// GetPATInfo implements Manager.
+func (m *MockManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
+ if m.GetPATInfoFunc != nil {
+ return m.GetPATInfoFunc(ctx, token)
+ }
+ return &types.User{}, &types.PersonalAccessToken{}, "", "", nil
+}
+
+// MarkPATUsed implements Manager.
+func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
+ if m.MarkPATUsedFunc != nil {
+ return m.MarkPATUsedFunc(ctx, tokenID)
+ }
+ return nil
+}
+
+// ValidateAndParseToken implements Manager.
+func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
+ if m.ValidateAndParseTokenFunc != nil {
+ return m.ValidateAndParseTokenFunc(ctx, value)
+ }
+ return nbcontext.UserAuth{}, &jwt.Token{}, nil
+}
diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go
new file mode 100644
index 000000000..55fb1e31a
--- /dev/null
+++ b/management/server/auth/manager_test.go
@@ -0,0 +1,407 @@
+package auth_test
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/management/server/auth"
+ nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/store"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
+ store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ if err != nil {
+ t.Fatalf("Error when creating store: %s", err)
+ }
+ t.Cleanup(cleanup)
+
+ token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
+ hashedToken := sha256.Sum256([]byte(token))
+ encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
+ account := &types.Account{
+ Id: "account_id",
+ Users: map[string]*types.User{"someUser": {
+ Id: "someUser",
+ PATs: map[string]*types.PersonalAccessToken{
+ "tokenId": {
+ ID: "tokenId",
+ UserID: "someUser",
+ HashedToken: encodedHashedToken,
+ },
+ },
+ }},
+ }
+
+ err = store.SaveAccount(context.Background(), account)
+ if err != nil {
+ t.Fatalf("Error when saving account: %s", err)
+ }
+
+ manager := auth.NewManager(store, "", "", "", "", []string{}, false)
+
+ user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
+ if err != nil {
+ t.Fatalf("Error when getting Account from PAT: %s", err)
+ }
+
+ assert.Equal(t, "account_id", user.AccountID)
+ assert.Equal(t, "someUser", user.Id)
+ assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
+}
+
+func TestAuthManager_MarkPATUsed(t *testing.T) {
+ store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ if err != nil {
+ t.Fatalf("Error when creating store: %s", err)
+ }
+ t.Cleanup(cleanup)
+
+ token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
+ hashedToken := sha256.Sum256([]byte(token))
+ encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
+ account := &types.Account{
+ Id: "account_id",
+ Users: map[string]*types.User{"someUser": {
+ Id: "someUser",
+ PATs: map[string]*types.PersonalAccessToken{
+ "tokenId": {
+ ID: "tokenId",
+ HashedToken: encodedHashedToken,
+ },
+ },
+ }},
+ }
+
+ err = store.SaveAccount(context.Background(), account)
+ if err != nil {
+ t.Fatalf("Error when saving account: %s", err)
+ }
+
+ manager := auth.NewManager(store, "", "", "", "", []string{}, false)
+
+ err = manager.MarkPATUsed(context.Background(), "tokenId")
+ if err != nil {
+ t.Fatalf("Error when marking PAT used: %s", err)
+ }
+
+ account, err = store.GetAccount(context.Background(), "account_id")
+ if err != nil {
+ t.Fatalf("Error when getting account: %s", err)
+ }
+ assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
+}
+
+func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
+ store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ if err != nil {
+ t.Fatalf("Error when creating store: %s", err)
+ }
+ t.Cleanup(cleanup)
+
+ userId := "user-id"
+ domain := "test.domain"
+
+ account := &types.Account{
+ Id: "account_id",
+ Domain: domain,
+ Users: map[string]*types.User{"someUser": {
+ Id: "someUser",
+ }},
+ Settings: &types.Settings{},
+ }
+
+ err = store.SaveAccount(context.Background(), account)
+ if err != nil {
+ t.Fatalf("Error when saving account: %s", err)
+ }
+
+ // this has been validated and parsed by ValidateAndParseToken
+ userAuth := nbcontext.UserAuth{
+ AccountId: account.Id,
+ Domain: domain,
+ UserId: userId,
+ DomainCategory: "test-category",
+ // Groups: []string{"group1", "group2"},
+ }
+
+ // these tests only assert groups are parsed from token as per account settings
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
+
+ manager := auth.NewManager(store, "", "", "", "", []string{}, false)
+
+ t.Run("JWT groups disabled", func(t *testing.T) {
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+ require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
+ })
+
+ t.Run("User impersonated", func(t *testing.T) {
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+ require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
+ })
+
+ t.Run("User PAT", func(t *testing.T) {
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+ require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
+ })
+
+ t.Run("JWT groups enabled without claim name", func(t *testing.T) {
+ account.Settings.JWTGroupsEnabled = true
+ err := store.SaveAccount(context.Background(), account)
+ require.NoError(t, err, "save account failed")
+
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+ require.Len(t, userAuth.Groups, 0, "account missing groups claim name")
+ })
+
+ t.Run("JWT groups enabled without allowed groups", func(t *testing.T) {
+ account.Settings.JWTGroupsEnabled = true
+ account.Settings.JWTGroupsClaimName = "idp-groups"
+ err := store.SaveAccount(context.Background(), account)
+ require.NoError(t, err, "save account failed")
+
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+ require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
+ })
+
+ t.Run("User in allowed JWT groups", func(t *testing.T) {
+ account.Settings.JWTGroupsEnabled = true
+ account.Settings.JWTGroupsClaimName = "idp-groups"
+ account.Settings.JWTAllowGroups = []string{"group1"}
+ err := store.SaveAccount(context.Background(), account)
+ require.NoError(t, err, "save account failed")
+
+ userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.NoError(t, err, "ensure user access by JWT groups failed")
+
+ require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
+ })
+
+ t.Run("User not in allowed JWT groups", func(t *testing.T) {
+ account.Settings.JWTGroupsEnabled = true
+ account.Settings.JWTGroupsClaimName = "idp-groups"
+ account.Settings.JWTAllowGroups = []string{"not-a-group"}
+ err := store.SaveAccount(context.Background(), account)
+ require.NoError(t, err, "save account failed")
+
+ _, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
+ require.Error(t, err, "ensure user access is not in allowed groups")
+ })
+}
+
+func TestAuthManager_ValidateAndParseToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("Cache-Control", "max-age=30") // set a 30s expiry to these keys
+ http.ServeFile(w, r, "test_data/jwks.json")
+ }))
+ defer server.Close()
+
+ issuer := "http://issuer.local"
+ audience := "http://audience.local"
+ userIdClaim := "" // defaults to "sub"
+
+ // we're only testing with RSA256
+ keyData, _ := os.ReadFile("test_data/sample_key")
+ key, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
+ keyId := "test-key"
+
+ // note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
+ manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
+
+ customClaim := func(name string) string {
+ return fmt.Sprintf("%s/%s", audience, name)
+ }
+
+ lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z"
+
+ tests := []struct {
+ name string
+ tokenFunc func() string
+ expected *nbcontext.UserAuth // nil indicates expected error
+ }{
+ {
+ name: "Valid with custom claims",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour * 1).Unix(),
+ "sub": "user-id|123",
+ customClaim(nbjwt.AccountIDSuffix): "account-id|567",
+ customClaim(nbjwt.DomainIDSuffix): "http://localhost",
+ customClaim(nbjwt.DomainCategorySuffix): "private",
+ customClaim(nbjwt.LastLoginSuffix): lastLogin.Format(time.RFC3339),
+ customClaim(nbjwt.Invited): false,
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ expected: &nbcontext.UserAuth{
+ UserId: "user-id|123",
+ AccountId: "account-id|567",
+ Domain: "http://localhost",
+ DomainCategory: "private",
+ LastLogin: lastLogin,
+ Invited: false,
+ },
+ },
+ {
+ name: "Valid without custom claims",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ expected: &nbcontext.UserAuth{
+ UserId: "user-id|123",
+ },
+ },
+ {
+ name: "Expired token",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Add(time.Hour * -2).Unix(),
+ "exp": time.Now().Add(time.Hour * -1).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ },
+ {
+ name: "Not yet valid",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Add(time.Hour).Unix(),
+ "exp": time.Now().Add(time.Hour * 2).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ },
+ {
+ name: "Invalid signature",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ parts := strings.Split(tokenString, ".")
+ parts[2] = "invalid-signature"
+ return strings.Join(parts, ".")
+ },
+ },
+ {
+ name: "Invalid issuer",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": "not-the-issuer",
+ "aud": []string{audience},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ },
+ {
+ name: "Invalid audience",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{"not-the-audience"},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ },
+ {
+ name: "Invalid user claim",
+ tokenFunc: func() string {
+ token := jwt.New(jwt.SigningMethodRS256)
+ token.Header["kid"] = keyId
+ token.Claims = jwt.MapClaims{
+ "iss": issuer,
+ "aud": []string{audience},
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "not-sub": "user-id|123",
+ }
+ tokenString, _ := token.SignedString(key)
+ return tokenString
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tokenString := tt.tokenFunc()
+
+ userAuth, token, err := manager.ValidateAndParseToken(context.Background(), tokenString)
+
+ if tt.expected != nil {
+ assert.NoError(t, err)
+ assert.True(t, token.Valid)
+ assert.Equal(t, *tt.expected, userAuth)
+ } else {
+ assert.Error(t, err)
+ assert.Nil(t, token)
+ assert.Empty(t, userAuth)
+ }
+ })
+ }
+
+}
diff --git a/management/server/auth/test_data/jwks.json b/management/server/auth/test_data/jwks.json
new file mode 100644
index 000000000..8080f5599
--- /dev/null
+++ b/management/server/auth/test_data/jwks.json
@@ -0,0 +1,11 @@
+{
+ "keys": [
+ {
+ "kty": "RSA",
+ "kid": "test-key",
+ "use": "sig",
+ "n": "4f5wg5l2hKsTeNem_V41fGnJm6gOdrj8ym3rFkEU_wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn_MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR-1DcKJzQBSTAGnpYVaqpsARap-nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7w",
+ "e": "AQAB"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/management/server/auth/test_data/sample_key b/management/server/auth/test_data/sample_key
new file mode 100644
index 000000000..e69284a3f
--- /dev/null
+++ b/management/server/auth/test_data/sample_key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn
+SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i
+cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC
+PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR
+ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA
+Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3
+n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy
+MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9
+POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE
+KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM
+IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn
+FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY
+mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj
+FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U
+I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs
+2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn
+/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT
+OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86
+EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+
+hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0
+4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb
+mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry
+eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3
+CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+
+9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq
+-----END RSA PRIVATE KEY-----
\ No newline at end of file
diff --git a/management/server/auth/test_data/sample_key.pub b/management/server/auth/test_data/sample_key.pub
new file mode 100644
index 000000000..d5b7f7102
--- /dev/null
+++ b/management/server/auth/test_data/sample_key.pub
@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41
+fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7
+mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp
+HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2
+XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b
+ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
+7wIDAQAB
+-----END PUBLIC KEY-----
\ No newline at end of file
diff --git a/management/server/config.go b/management/server/config.go
index 397b5f0e6..ce2ff4d16 100644
--- a/management/server/config.go
+++ b/management/server/config.go
@@ -2,7 +2,6 @@ package server
import (
"net/netip"
- "net/url"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
@@ -180,9 +179,3 @@ type ReverseProxy struct {
// trusted IP prefixes.
TrustedPeers []netip.Prefix
}
-
-// validateURL validates input http url
-func validateURL(httpURL string) bool {
- _, err := url.ParseRequestURI(httpURL)
- return err == nil
-}
diff --git a/management/server/context/auth.go b/management/server/context/auth.go
new file mode 100644
index 000000000..5cb28ddb7
--- /dev/null
+++ b/management/server/context/auth.go
@@ -0,0 +1,60 @@
+package context
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+type key int
+
+const (
+ UserAuthContextKey key = iota
+)
+
+type UserAuth struct {
+ // The account id the user is accessing
+ AccountId string
+ // The account domain
+ Domain string
+ // The account domain category, TBC values
+ DomainCategory string
+ // Indicates whether this user was invited, TBC logic
+ Invited bool
+ // Indicates whether this is a child account
+ IsChild bool
+
+ // The user id
+ UserId string
+ // Last login time for this user
+ LastLogin time.Time
+ // The Groups the user belongs to on this account
+ Groups []string
+
+ // Indicates whether this user has authenticated with a Personal Access Token
+ IsPAT bool
+}
+
+func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
+ return GetUserAuthFromContext(r.Context())
+}
+
+func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
+ return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
+}
+
+func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
+ if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
+ return userAuth, nil
+ }
+ return UserAuth{}, fmt.Errorf("user auth not in context")
+}
+
+func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
+ //nolint
+ ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
+ //nolint
+ ctx = context.WithValue(ctx, AccountIDKey, userAuth.AccountId)
+ return context.WithValue(ctx, UserAuthContextKey, userAuth)
+}
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index c3a36153a..9f77fd242 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
@@ -39,11 +39,10 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
secretsManager SecretsManager
- jwtValidator jwtclaims.JWTValidator
- jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
+ authManager auth.Manager
}
// NewServer creates a new Management server
@@ -56,29 +55,13 @@ func NewServer(
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager,
+ authManager auth.Manager,
) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
- var jwtValidator jwtclaims.JWTValidator
-
- if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
- jwtValidator, err = jwtclaims.NewJWTValidator(
- ctx,
- config.HttpConfig.AuthIssuer,
- config.GetAuthAudiences(),
- config.HttpConfig.AuthKeysLocation,
- config.HttpConfig.IdpSignKeyRefreshEnabled,
- )
- if err != nil {
- return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
- }
- } else {
- log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
- }
-
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
@@ -89,16 +72,6 @@ func NewServer(
}
}
- var audience, userIDClaim string
- if config.HttpConfig != nil {
- audience = config.HttpConfig.AuthAudience
- userIDClaim = config.HttpConfig.AuthUserIDClaim
- }
- jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(audience),
- jwtclaims.WithUserIDClaim(userIDClaim),
- )
-
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -107,8 +80,7 @@ func NewServer(
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
- jwtValidator: jwtValidator,
- jwtClaimsExtractor: jwtClaimsExtractor,
+ authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
}, nil
@@ -294,26 +266,37 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
- if s.jwtValidator == nil {
- return "", status.Error(codes.Internal, "no jwt validator set")
+ if s.authManager == nil {
+ return "", status.Errorf(codes.Internal, "missing auth manager")
}
- token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
+ userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
- claims := s.jwtClaimsExtractor.FromToken(token)
+
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
- _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
+ accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
- if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
+ if userAuth.AccountId != accountId {
+ log.WithContext(ctx).Debugf("gRPC server sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
+ userAuth.AccountId = accountId
+ }
+
+ userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token)
+ if err != nil {
return "", status.Error(codes.PermissionDenied, err.Error())
}
- return claims.UserId, nil
+ err = s.accountManager.SyncUserJWTGroups(ctx, userAuth)
+ if err != nil {
+ log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err)
+ }
+
+ return userAuth.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index eb1cfb5dd..7dd277daa 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions"
s "github.com/netbirdio/netbird/management/server"
+ "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
@@ -27,8 +27,12 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
+<<<<<<< HEAD
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
+=======
+ "github.com/netbirdio/netbird/management/server/integrated_validator"
+>>>>>>> main
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@@ -39,55 +43,63 @@ import (
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
+<<<<<<< HEAD
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
+=======
+func NewAPIHandler(
+ ctx context.Context,
+ accountManager s.AccountManager,
+ networksManager nbnetworks.Manager,
+ resourceManager resources.Manager,
+ routerManager routers.Manager,
+ groupsManager nbgroups.Manager,
+ LocationManager geolocation.Geolocation,
+ authManager auth.Manager,
+ appMetrics telemetry.AppMetrics,
+ config *s.Config,
+ integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
+>>>>>>> main
authMiddleware := middleware.NewAuthMiddleware(
- accountManager.GetPATInfo,
- jwtValidator.ValidateAndParse,
- accountManager.MarkPATUsed,
- accountManager.CheckUserAccessByJWTGroups,
- claimsExtractor,
- authCfg.Audience,
- authCfg.UserIDClaim,
+ authManager,
+ accountManager.GetAccountIDFromUserAuth,
+ accountManager.SyncUserJWTGroups,
)
corsMiddleware := cors.AllowAll()
- claimsExtractor = jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- )
-
- acMiddleware := middleware.NewAccessControl(
- authCfg.Audience,
- authCfg.UserIDClaim,
- accountManager.GetUser)
+ acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
+
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
+<<<<<<< HEAD
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController); err != nil {
+=======
+ if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter()); err != nil {
+>>>>>>> main
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
- accounts.AddEndpoints(accountManager, authCfg, router)
- peers.AddEndpoints(accountManager, authCfg, router)
- users.AddEndpoints(accountManager, authCfg, router)
- setup_keys.AddEndpoints(accountManager, authCfg, router)
- policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
- groups.AddEndpoints(accountManager, authCfg, router)
- routes.AddEndpoints(accountManager, authCfg, router)
- dns.AddEndpoints(accountManager, authCfg, router)
- events.AddEndpoints(accountManager, authCfg, router)
- networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
+ accounts.AddEndpoints(accountManager, router)
+ peers.AddEndpoints(accountManager, router)
+ users.AddEndpoints(accountManager, router)
+ setup_keys.AddEndpoints(accountManager, router)
+ policies.AddEndpoints(accountManager, LocationManager, router)
+ groups.AddEndpoints(accountManager, router)
+ routes.AddEndpoints(accountManager, router)
+ dns.AddEndpoints(accountManager, router)
+ events.AddEndpoints(accountManager, router)
+ networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
return rootRouter, nil
}
diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go
index a23628cdc..bc0054a7f 100644
--- a/management/server/http/handlers/accounts/accounts_handler.go
+++ b/management/server/http/handlers/accounts/accounts_handler.go
@@ -9,47 +9,42 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- accountsHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ accountsHandler := newHandler(accountManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -62,13 +57,14 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ _, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
@@ -125,7 +121,12 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
// deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
@@ -133,7 +134,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
- err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
+ err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go
index e8a599863..a8d57a13f 100644
--- a/management/server/http/handlers/accounts/accounts_handler_test.go
+++ b/management/server/http/handlers/accounts/accounts_handler_test.go
@@ -13,19 +13,16 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
-func initAccountsTestData(account *types.Account, admin *types.User) *handler {
+func initAccountsTestData(account *types.Account) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
- GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return account.Id, admin.Id, nil
- },
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
},
@@ -44,15 +41,6 @@ func initAccountsTestData(account *types.Account, admin *types.User) *handler {
return accCopy, nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_account",
- }
- }),
- ),
}
}
@@ -75,7 +63,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
PeerLoginExpiration: time.Hour,
RegularUsersViewBlocked: true,
},
- }, adminUser)
+ })
tt := []struct {
name string
@@ -191,6 +179,11 @@ func TestAccounts_AccountsHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: adminUser.Id,
+ AccountId: accountID,
+ Domain: "hotmail.com",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go
index 112eee179..6ff938369 100644
--- a/management/server/http/handlers/dns/dns_settings_handler.go
+++ b/management/server/http/handlers/dns/dns_settings_handler.go
@@ -8,51 +8,44 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account
type dnsSettingsHandler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- addDNSSettingEndpoint(accountManager, authCfg, router)
- addDNSNameserversEndpoint(accountManager, authCfg, router)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ addDNSSettingEndpoint(accountManager, router)
+ addDNSNameserversEndpoint(accountManager, router)
}
-func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
+func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) {
+ dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
-func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
- return &dnsSettingsHandler{
- accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
- }
+func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler {
+ return &dnsSettingsHandler{accountManager: accountManager}
}
// getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -68,13 +61,14 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
// updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.PutApiDnsSettingsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go
index 9ca1dc032..ca81adf43 100644
--- a/management/server/http/handlers/dns/dns_settings_handler_test.go
+++ b/management/server/http/handlers/dns/dns_settings_handler_test.go
@@ -17,7 +17,8 @@ import (
"github.com/gorilla/mux"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -52,19 +53,7 @@ func initDNSSettingsTestData() *dnsSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
- GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
- return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
- },
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: testDNSSettingsAccountID,
- }
- }),
- ),
}
}
@@ -118,6 +107,11 @@ func TestDNSSettingsHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
+ AccountId: testingDNSSettingsAccount.Id,
+ Domain: testingDNSSettingsAccount.Domain,
+ })
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go
index 09047e231..33d070477 100644
--- a/management/server/http/handlers/dns/nameservers_handler.go
+++ b/management/server/http/handlers/dns/nameservers_handler.go
@@ -10,21 +10,19 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// nameserversHandler is the nameserver group handler of the account
type nameserversHandler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- nameserversHandler := newNameserversHandler(accountManager, authCfg)
+func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) {
+ nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
@@ -33,26 +31,21 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg con
}
// newNameserversHandler returns a new instance of nameserversHandler handler
-func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
- return &nameserversHandler{
- accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
- }
+func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler {
+ return &nameserversHandler{accountManager: accountManager}
}
// getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -69,13 +62,14 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
// createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.PostApiDnsNameserversJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -102,13 +96,14 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
// updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -153,13 +148,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
// deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -177,14 +173,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
// getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
- log.WithContext(r.Context()).Error(err)
- http.Redirect(w, r, "/", http.StatusInternalServerError)
+ util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go
index c6561e4d8..45283bc37 100644
--- a/management/server/http/handlers/dns/nameservers_handler_test.go
+++ b/management/server/http/handlers/dns/nameservers_handler_test.go
@@ -18,7 +18,8 @@ import (
"github.com/gorilla/mux"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -81,19 +82,7 @@ func initNameserversTestData() *nameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: testNSGroupAccountID,
- }
- }),
- ),
}
}
@@ -204,6 +193,11 @@ func TestNameserversHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ AccountId: testNSGroupAccountID,
+ Domain: "hotmail.com",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go
index 62da59535..0fb2295a8 100644
--- a/management/server/http/handlers/events/events_handler.go
+++ b/management/server/http/handlers/events/events_handler.go
@@ -10,44 +10,37 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
)
// handler HTTP handler
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- eventsHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
- return &handler{
- accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
- }
+func newHandler(accountManager server.AccountManager) *handler {
+ return &handler{accountManager: accountManager}
}
// getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go
index fd603f289..3a643fe90 100644
--- a/management/server/http/handlers/events/events_handler_test.go
+++ b/management/server/http/handlers/events/events_handler_test.go
@@ -13,9 +13,10 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -29,22 +30,10 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
}
return []*activity.Event{}, nil
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
return make(map[string]*types.UserInfo), nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_account",
- }
- }),
- ),
}
}
@@ -199,6 +188,11 @@ func TestEvents_GetEvents(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_account",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go
index 383a33dea..2d0b8bdbd 100644
--- a/management/server/http/handlers/groups/groups_handler.go
+++ b/management/server/http/handlers/groups/groups_handler.go
@@ -7,24 +7,23 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns groups of the account
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- groupsHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ groupsHandler := newHandler(accountManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
@@ -33,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler creates a new groups handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
@@ -75,13 +70,14 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
// updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
@@ -164,13 +160,14 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
// createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.PostApiGroupsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -223,13 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
// deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
@@ -253,12 +251,13 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
// getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+
+ accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go
index c82d33240..f4ac34e53 100644
--- a/management/server/http/handlers/groups/groups_handler_test.go
+++ b/management/server/http/handlers/groups/groups_handler_test.go
@@ -18,9 +18,9 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@@ -59,9 +59,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return group, nil
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
@@ -87,15 +84,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_id",
- }
- }),
- ),
}
}
@@ -134,6 +122,11 @@ func TestGetGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
@@ -255,6 +248,11 @@ func TestWriteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
@@ -332,7 +330,11 @@ func TestDeleteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
-
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.ServeHTTP(recorder, req)
diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go
index f716348d6..bb6b97267 100644
--- a/management/server/http/handlers/networks/handler.go
+++ b/management/server/http/handlers/networks/handler.go
@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus"
s "github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@@ -31,16 +30,14 @@ type handler struct {
routerManager routers.Manager
accountManager s.AccountManager
- groupsManager groups.Manager
- extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
- claimsExtractor *jwtclaims.ClaimsExtractor
+ groupsManager groups.Manager
}
-func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
- addRouterEndpoints(routerManager, extractFromToken, authCfg, router)
- addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router)
+func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) {
+ addRouterEndpoints(routerManager, router)
+ addResourceEndpoints(resourceManager, groupsManager, router)
- networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg)
+ networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
@@ -48,29 +45,25 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
}
-func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
+func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler {
return &handler{
- networksManager: networksManager,
- resourceManager: resourceManager,
- routerManager: routerManager,
- groupsManager: groupsManager,
- accountManager: accountManager,
- extractFromToken: extractFromToken,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
+ networksManager: networksManager,
+ resourceManager: resourceManager,
+ routerManager: routerManager,
+ groupsManager: groupsManager,
+ accountManager: accountManager,
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -105,12 +98,12 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
@@ -141,12 +134,12 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
@@ -179,13 +172,13 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -229,13 +222,13 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go
index f2dc8e3b8..fba7026e8 100644
--- a/management/server/http/handlers/networks/resources_handler.go
+++ b/management/server/http/handlers/networks/resources_handler.go
@@ -1,30 +1,26 @@
package networks
import (
- "context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
)
type resourceHandler struct {
- resourceManager resources.Manager
- groupsManager groups.Manager
- extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
- claimsExtractor *jwtclaims.ClaimsExtractor
+ resourceManager resources.Manager
+ groupsManager groups.Manager
}
-func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
- resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg)
+func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
+ resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
@@ -33,26 +29,21 @@ func addResourceEndpoints(resourcesManager resources.Manager, groupsManager grou
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
}
-func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler {
+func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
return &resourceHandler{
- resourceManager: resourceManager,
- groupsManager: groupsManager,
- extractFromToken: extractFromToken,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
+ resourceManager: resourceManager,
+ groupsManager: groupsManager,
}
}
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
@@ -76,13 +67,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -106,13 +98,14 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -144,13 +137,13 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
@@ -171,13 +164,13 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -209,12 +202,12 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go
index 7ca95d902..f98da4966 100644
--- a/management/server/http/handlers/networks/routers_handler.go
+++ b/management/server/http/handlers/networks/routers_handler.go
@@ -1,28 +1,24 @@
package networks
import (
- "context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
)
type routersHandler struct {
- routersManager routers.Manager
- extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
- claimsExtractor *jwtclaims.ClaimsExtractor
+ routersManager routers.Manager
}
-func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
- routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg)
+func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
+ routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
@@ -30,25 +26,21 @@ func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ct
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
}
-func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler {
+func newRoutersHandler(routersManager routers.Manager) *routersHandler {
return &routersHandler{
- routersManager: routersManager,
- extractFromToken: extractFromToken,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
+ routersManager: routersManager,
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
@@ -65,13 +57,14 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
@@ -96,13 +89,14 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
@@ -115,13 +109,14 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -146,13 +141,13 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.extractFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go
index 723ab4f67..a907a0f24 100644
--- a/management/server/http/handlers/peers/peers_handler.go
+++ b/management/server/http/handlers/peers/peers_handler.go
@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@@ -22,12 +21,11 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- peersHandler := NewHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ peersHandler := NewHandler(accountManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -35,13 +33,9 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// NewHandler creates a new peers Handler
-func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler {
+func NewHandler(accountManager server.AccountManager) *Handler {
return &Handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
@@ -149,12 +143,13 @@ func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peer
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -179,17 +174,22 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+<<<<<<< HEAD
nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip")
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
+=======
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
+ peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+>>>>>>> main
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -233,13 +233,14 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPee
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go
index 248611bd2..cb60ae4f1 100644
--- a/management/server/http/handlers/peers/peers_handler_test.go
+++ b/management/server/http/handlers/peers/peers_handler_test.go
@@ -15,8 +15,8 @@ import (
"github.com/gorilla/mux"
"golang.org/x/exp/maps"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
@@ -25,16 +25,13 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
-type ctxKey string
-
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
- adminUser = "admin_user"
- regularUser = "regular_user"
- serviceUser = "service_user"
- userIDKey ctxKey = "user_id"
+ adminUser = "admin_user"
+ regularUser = "regular_user"
+ serviceUser = "service_user"
)
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
@@ -146,9 +143,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
return account, nil
},
@@ -167,16 +161,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return ok
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- userID := r.Context().Value(userIDKey).(string)
- return jwtclaims.AuthorizationClaims{
- UserId: userID,
- Domain: "hotmail.com",
- AccountId: "test_id",
- }
- }),
- ),
}
}
@@ -267,8 +251,11 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
- ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
- req = req.WithContext(ctx)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "admin_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -412,8 +399,11 @@ func TestGetAccessiblePeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
- ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
- req = req.WithContext(ctx)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: tc.callerUserID,
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go
index fc5839baa..fbdc324d6 100644
--- a/management/server/http/handlers/policies/geolocation_handler_test.go
+++ b/management/server/http/handlers/policies/geolocation_handler_test.go
@@ -13,9 +13,9 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@@ -43,23 +43,11 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
return &geolocationsHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return types.NewAdminUser(id), nil
},
},
geolocationManager: geo,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_id",
- }
- }),
- ),
}
}
@@ -112,6 +100,11 @@ func TestGetCitiesByCountry(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
@@ -200,6 +193,11 @@ func TestGetAllCountries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go
index 161d97402..c4868f879 100644
--- a/management/server/http/handlers/policies/geolocations_handler.go
+++ b/management/server/http/handlers/policies/geolocations_handler.go
@@ -7,11 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -23,24 +22,19 @@ var (
type geolocationsHandler struct {
accountManager server.AccountManager
geolocationManager geolocation.Geolocation
- claimsExtractor *jwtclaims.ClaimsExtractor
}
-func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
- locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
+func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
+ locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
-func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
+func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
@@ -104,12 +98,13 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
- claims := l.claimsExtractor.FromRequestContext(r)
- _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
return err
}
+ _, userID := userAuth.AccountId, userAuth.UserId
+
user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
return err
diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go
index a748e73b8..63fc8a03b 100644
--- a/management/server/http/handlers/policies/policies_handler.go
+++ b/management/server/http/handlers/policies/policies_handler.go
@@ -8,51 +8,46 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
- policiesHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
+ policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
- addPostureCheckEndpoint(accountManager, locationManager, authCfg, router)
+ addPostureCheckEndpoint(accountManager, locationManager, router)
}
// newHandler creates a new policies handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -80,13 +75,14 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
// updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -105,13 +101,14 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
// createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
h.savePolicy(w, r, accountID, userID, "")
}
@@ -306,13 +303,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
// deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -330,13 +327,14 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
// getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go
index 8fbf84d4b..6450295eb 100644
--- a/management/server/http/handlers/policies/policies_handler_test.go
+++ b/management/server/http/handlers/policies/policies_handler_test.go
@@ -13,8 +13,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@@ -44,9 +44,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
return []*types.Group{{ID: "F"}, {ID: "G"}}, nil
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
user := types.NewAdminUser(userID)
return &types.Account{
@@ -65,15 +62,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
}, nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_id",
- }
- }),
- ),
}
}
@@ -115,6 +103,11 @@ func TestPoliciesGetPolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
@@ -274,6 +267,11 @@ func TestPoliciesWritePolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go
index ce0d4878c..e6e58da58 100644
--- a/management/server/http/handlers/policies/posture_checks_handler.go
+++ b/management/server/http/handlers/policies/posture_checks_handler.go
@@ -7,11 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -20,40 +19,35 @@ import (
type postureChecksHandler struct {
accountManager server.AccountManager
geolocationManager geolocation.Geolocation
- claimsExtractor *jwtclaims.ClaimsExtractor
}
-func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
- postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
+func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
+ postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
- addLocationsEndpoint(accountManager, locationManager, authCfg, router)
+ addLocationsEndpoint(accountManager, locationManager, router)
}
// newPostureChecksHandler creates a new PostureChecks handler
-func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
+func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *postureChecksHandler {
return &postureChecksHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
- claims := p.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -70,13 +64,14 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
// updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
- claims := p.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -95,25 +90,26 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
- claims := p.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
p.savePostureChecks(w, r, accountID, userID, "")
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
- claims := p.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -132,13 +128,13 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
// deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
- claims := p.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go
index 237687fd4..e3844caa2 100644
--- a/management/server/http/handlers/policies/posture_checks_handler_test.go
+++ b/management/server/http/handlers/policies/posture_checks_handler_test.go
@@ -14,9 +14,9 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@@ -66,20 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
}
return accountPostureChecks, nil
},
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
},
geolocationManager: &geolocation.Mock{},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: "test_id",
- }
- }),
- ),
}
}
@@ -187,6 +175,11 @@ func TestGetPostureCheck(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
@@ -835,6 +828,11 @@ func TestPostureCheckUpdate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: "test_id",
+ })
defaultHandler := *p
if tc.setupHandlerFunc != nil {
diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go
index 6b6c37910..0f0d24780 100644
--- a/management/server/http/handlers/routes/routes_handler.go
+++ b/management/server/http/handlers/routes/routes_handler.go
@@ -10,10 +10,9 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
)
@@ -22,12 +21,11 @@ const failedToConvertRoute = "failed to convert route to response: %v"
// handler is the routes handler of the account
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- routesHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
@@ -36,25 +34,22 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler returns a new instance of routes handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -75,13 +70,14 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
// createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
var req api.PostApiRoutesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -172,13 +168,13 @@ func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
// updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
@@ -265,13 +261,13 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
// deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
@@ -289,13 +285,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
// getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
+
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go
index f3bd79ee4..ad1f8912d 100644
--- a/management/server/http/handlers/routes/routes_handler_test.go
+++ b/management/server/http/handlers/routes/routes_handler_test.go
@@ -16,12 +16,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
- "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
)
@@ -60,32 +58,6 @@ var baseExistingRoute = &route.Route{
Groups: []string{existingGroupID},
}
-var testingAccount = &types.Account{
- Id: testAccountID,
- Domain: "hotmail.com",
- Peers: map[string]*nbpeer.Peer{
- existingPeerID: {
- Key: existingPeerKey,
- IP: netip.MustParseAddr(existingPeerIP1).AsSlice(),
- ID: existingPeerID,
- Meta: nbpeer.PeerSystemMeta{
- GoOS: "linux",
- },
- },
- nonLinuxExistingPeerID: {
- Key: nonLinuxExistingPeerID,
- IP: netip.MustParseAddr(existingPeerIP2).AsSlice(),
- ID: nonLinuxExistingPeerID,
- Meta: nbpeer.PeerSystemMeta{
- GoOS: "darwin",
- },
- },
- },
- Users: map[string]*types.User{
- "test_user": types.NewAdminUser("test_user"),
- },
-}
-
func initRoutesTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
@@ -150,20 +122,7 @@ func initRoutesTestData() *handler {
}
return nil
},
- GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
- // return testingAccount, testingAccount.Users["test_user"], nil
- return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
- },
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
- Domain: "hotmail.com",
- AccountId: testAccountID,
- }
- }),
- ),
}
}
@@ -526,6 +485,11 @@ func TestRoutesHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: "test_user",
+ Domain: "hotmail.com",
+ AccountId: testAccountID,
+ })
router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go
index 3bd3ef589..8095f43b0 100644
--- a/management/server/http/handlers/setup_keys/setupkeys_handler.go
+++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go
@@ -10,22 +10,20 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- keysHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
@@ -34,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler creates a new setup key handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -108,12 +102,12 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
// getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
@@ -133,13 +127,13 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
// updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -174,13 +168,13 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
// getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -196,13 +190,13 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
index 4912f9639..e9135469f 100644
--- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
+++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go
@@ -14,8 +14,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@@ -28,14 +28,9 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID"
)
-func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
- user *types.User,
-) *handler {
+func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool, allowExtraDNSLabels bool,
) (*types.SetupKey, error) {
@@ -76,15 +71,6 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
return status.Errorf(status.NotFound, "key %s not found", keyID)
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: user.Id,
- Domain: "hotmail.com",
- AccountId: "testAccountId",
- }
- }),
- ),
}
}
@@ -171,12 +157,17 @@ func TestSetupKeysHandlers(t *testing.T) {
},
}
- handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
+ handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: adminUser.Id,
+ Domain: "hotmail.com",
+ AccountId: "testAccountId",
+ })
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go
index 7b93d2ae1..84fbef93e 100644
--- a/management/server/http/handlers/users/pat_handler.go
+++ b/management/server/http/handlers/users/pat_handler.go
@@ -7,22 +7,20 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account
type patHandler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- tokenHandler := newPATsHandler(accountManager, authCfg)
+func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) {
+ tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
@@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config
}
// newPATsHandler creates a new patHandler HTTP handler
-func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
+func newPATsHandler(accountManager server.AccountManager) *patHandler {
return &patHandler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
@@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go
index 9388067a4..6593de64a 100644
--- a/management/server/http/handlers/users/pat_handler_test.go
+++ b/management/server/http/handlers/users/pat_handler_test.go
@@ -12,11 +12,12 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
- "github.com/netbirdio/netbird/management/server/util"
"github.com/stretchr/testify/assert"
+ "github.com/netbirdio/netbird/management/server/util"
+
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@@ -77,10 +78,6 @@ func initPATTestData() *patHandler {
PersonalAccessToken: types.PersonalAccessToken{},
}, nil
},
-
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return claims.AccountId, claims.UserId, nil
- },
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
@@ -115,15 +112,6 @@ func initPATTestData() *patHandler {
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: existingUserID,
- Domain: testDomain,
- AccountId: existingAccountID,
- }
- }),
- ),
}
}
@@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go
index 7380dd97e..3869f21f0 100644
--- a/management/server/http/handlers/users/users_handler.go
+++ b/management/server/http/handlers/users/users_handler.go
@@ -9,39 +9,33 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
type handler struct {
- accountManager server.AccountManager
- claimsExtractor *jwtclaims.ClaimsExtractor
+ accountManager server.AccountManager
}
-func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
- userHandler := newHandler(accountManager, authCfg)
+func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
+ userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
- addUsersTokensEndpoint(accountManager, authCfg, router)
+ addUsersTokensEndpoint(accountManager, router)
}
// newHandler creates a new UsersHandler HTTP handler
-func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
+func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(authCfg.Audience),
- jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
- ),
}
}
@@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
+ util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// deleteUser is a DELETE request to delete a user
@@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
return
}
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
@@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
+ util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// getAllUsers returns a list of users of the account this user belongs to.
@@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue
}
if serviceUser == "" {
- users = append(users, toUserResponse(d, claims.UserId))
+ users = append(users, toUserResponse(d, userID))
continue
}
@@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
if includeServiceUser == d.IsServiceUser {
- users = append(users, toUserResponse(d, claims.UserId))
+ users = append(users, toUserResponse(d, userID))
}
}
@@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return
}
- claims := h.claimsExtractor.FromRequestContext(r)
- accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+ accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go
index ff77cedff..a6a904a4c 100644
--- a/management/server/http/handlers/users/users_handler_test.go
+++ b/management/server/http/handlers/users/users_handler_test.go
@@ -13,8 +13,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{
func initUsersTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
- GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- return usersTestAccount.Id, claims.UserId, nil
- },
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return usersTestAccount.Users[id], nil
},
@@ -127,15 +124,6 @@ func initUsersTestData() *handler {
return nil
},
},
- claimsExtractor: jwtclaims.NewClaimsExtractor(
- jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
- return jwtclaims.AuthorizationClaims{
- UserId: existingUserID,
- Domain: testDomain,
- AccountId: existingAccountID,
- }
- }),
- ),
}
}
@@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
userHandler.getAllUsers(recorder, req)
@@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
@@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
userHandler.createUser(rr, req)
@@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
+
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
@@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
+ req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
+ UserId: existingUserID,
+ Domain: testDomain,
+ AccountId: existingAccountID,
+ })
+
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)
diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go
index c5bdf5fe7..4ed90f47b 100644
--- a/management/server/http/middleware/access_control.go
+++ b/management/server/http/middleware/access_control.go
@@ -7,30 +7,24 @@ import (
log "github.com/sirupsen/logrus"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
-
- "github.com/netbirdio/netbird/management/server/jwtclaims"
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
-type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
+type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
- claimsExtract jwtclaims.ClaimsExtractor
- getUser GetUser
+ getUser GetUser
}
// NewAccessControl instance constructor
-func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
+func NewAccessControl(getUser GetUser) *AccessControl {
return &AccessControl{
- claimsExtract: *jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(audience),
- jwtclaims.WithUserIDClaim(userIDClaim),
- ),
getUser: getUser,
}
}
@@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return
}
- claims := a.claimsExtract.FromRequestContext(r)
-
- user, err := a.getUser(r.Context(), claims)
+ userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if err != nil {
- log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
- util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
+ log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
+ util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
+ }
+
+ user, err := a.getUser(r.Context(), userAuth)
+ if err != nil {
+ log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
+ util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
return
}
diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go
index dcf73259a..a8e6790a9 100644
--- a/management/server/http/middleware/auth_middleware.go
+++ b/management/server/http/middleware/auth_middleware.go
@@ -8,67 +8,41 @@ import (
"strings"
"time"
- "github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
- nbContext "github.com/netbirdio/netbird/management/server/context"
+ "github.com/netbirdio/netbird/management/server/auth"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
- "github.com/netbirdio/netbird/management/server/types"
)
-// GetAccountInfoFromPATFunc function
-type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
-
-// ValidateAndParseTokenFunc function
-type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
-
-// MarkPATUsedFunc function
-type MarkPATUsedFunc func(ctx context.Context, token string) error
-
-// CheckUserAccessByJWTGroupsFunc function
-type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
+type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
+type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
- getAccountInfoFromPAT GetAccountInfoFromPATFunc
- validateAndParseToken ValidateAndParseTokenFunc
- markPATUsed MarkPATUsedFunc
- checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
- claimsExtractor *jwtclaims.ClaimsExtractor
- audience string
- userIDClaim string
+ authManager auth.Manager
+ ensureAccount EnsureAccountFunc
+ syncUserJWTGroups SyncUserJWTGroupsFunc
}
-const (
- userProperty = "user"
-)
-
// NewAuthMiddleware instance constructor
-func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
- markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
- audience string, userIdClaim string) *AuthMiddleware {
- if userIdClaim == "" {
- userIdClaim = jwtclaims.UserIDClaim
- }
-
+func NewAuthMiddleware(
+ authManager auth.Manager,
+ ensureAccount EnsureAccountFunc,
+ syncUserJWTGroups SyncUserJWTGroupsFunc,
+) *AuthMiddleware {
return &AuthMiddleware{
- getAccountInfoFromPAT: getAccountInfoFromPAT,
- validateAndParseToken: validateAndParseToken,
- markPATUsed: markPATUsed,
- checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
- claimsExtractor: claimsExtractor,
- audience: audience,
- userIDClaim: userIdClaim,
+ authManager: authManager,
+ ensureAccount: ensureAccount,
+ syncUserJWTGroups: syncUserJWTGroups,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
@@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
- err := m.checkJWTFromRequest(w, r, auth)
+ request, err := m.checkJWTFromRequest(r, auth)
if err != nil {
- log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
+ log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
+
+ h.ServeHTTP(w, request)
case "token":
- err := m.checkPATFromRequest(w, r, auth)
+ request, err := m.checkPATFromRequest(r, auth)
if err != nil {
- log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
+ log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
+ h.ServeHTTP(w, request)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
- claims := m.claimsExtractor.FromRequestContext(r)
- //nolint
- ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
- //nolint
- ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
- h.ServeHTTP(w, r.WithContext(ctx))
})
}
// CheckJWTFromRequest checks if the JWT is valid
-func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
+func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
- return fmt.Errorf("Error extracting token: %w", err)
+ return r, fmt.Errorf("error extracting token: %w", err)
}
- validatedToken, err := m.validateAndParseToken(r.Context(), token)
+ ctx := r.Context()
+
+ userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil {
- return err
+ return r, err
}
- if validatedToken == nil {
- return nil
+ if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
+ userAuth.AccountId = impersonate[0]
+ userAuth.IsChild = ok
}
- if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
- return err
+ // we need to call this method because if user is new, we will automatically add it to existing or create a new account
+ accountId, _, err := m.ensureAccount(ctx, userAuth)
+ if err != nil {
+ return r, err
}
- // If we get here, everything worked and we can set the
- // user property in context.
- newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
- // Update the current request with the new context information.
- *r = *newRequest
- return nil
-}
+ if userAuth.AccountId != accountId {
+ log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
+ userAuth.AccountId = accountId
+ }
-// verifyUserAccess checks if a user, based on a validated JWT token,
-// is allowed access, particularly in cases where the admin enabled JWT
-// group propagation and designated certain groups with access permissions.
-func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
- authClaims := m.claimsExtractor.FromToken(validatedToken)
- return m.checkUserAccessByJWTGroups(ctx, authClaims)
+ userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
+ if err != nil {
+ return r, err
+ }
+
+ err = m.syncUserJWTGroups(ctx, userAuth)
+ if err != nil {
+ log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
+ }
+
+ return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// CheckPATFromRequest checks if the PAT is valid
-func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
+func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
if err != nil {
- return fmt.Errorf("error extracting token: %w", err)
+ return r, fmt.Errorf("error extracting token: %w", err)
}
- user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
+ ctx := r.Context()
+ user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
- return fmt.Errorf("invalid Token: %w", err)
+ return r, fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.GetExpirationDate()) {
- return fmt.Errorf("token expired")
+ return r, fmt.Errorf("token expired")
}
- err = m.markPATUsed(r.Context(), pat.ID)
+ err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
- return err
+ return r, err
}
- claimMaps := jwt.MapClaims{}
- claimMaps[m.userIDClaim] = user.Id
- claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
- claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
- claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
- claimMaps[jwtclaims.IsToken] = true
- jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
- newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
- // Update the current request with the new context information.
- *r = *newRequest
- return nil
+ userAuth := nbcontext.UserAuth{
+ UserId: user.Id,
+ AccountId: user.AccountID,
+ Domain: accDomain,
+ DomainCategory: accCategory,
+ IsPAT: true,
+ }
+
+ return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
- return "", errors.New("Authorization header format must be Bearer {token}")
+ return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
@@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
// the PAT token from the Authorization header.
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
- return "", errors.New("Authorization header format must be Token {token}")
+ return "", errors.New("authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go
index c1686ed44..3dc7d51cb 100644
--- a/management/server/http/middleware/auth_middleware_test.go
+++ b/management/server/http/middleware/auth_middleware_test.go
@@ -9,10 +9,14 @@ import (
"time"
"github.com/golang-jwt/jwt"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/netbirdio/netbird/management/server/auth"
+ nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
-func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
+func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
if token == JWT {
- return &jwt.Token{
- Claims: jwt.MapClaims{
- userIDClaim: userID,
- audience + jwtclaims.AccountIDSuffix: accountID,
+ return nbcontext.UserAuth{
+ UserId: userID,
+ AccountId: accountID,
+ Domain: testAccount.Domain,
+ DomainCategory: testAccount.DomainCategory,
},
- Valid: true,
- }, nil
+ &jwt.Token{
+ Claims: jwt.MapClaims{
+ userIDClaim: userID,
+ audience + nbjwt.AccountIDSuffix: accountID,
+ },
+ Valid: true,
+ }, nil
}
- return nil, fmt.Errorf("JWT invalid")
+ return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
-func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
- if testAccount.Id != claims.AccountId {
- return fmt.Errorf("account with id %s does not exist", claims.AccountId)
+func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
+ if userAuth.IsChild || userAuth.IsPAT {
+ return userAuth, nil
}
- if _, ok := testAccount.Users[claims.UserId]; !ok {
- return fmt.Errorf("user with id %s does not exist", claims.UserId)
+ if testAccount.Id != userAuth.AccountId {
+ return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId)
}
- return nil
+ if _, ok := testAccount.Users[userAuth.UserId]; !ok {
+ return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId)
+ }
+
+ return userAuth, nil
}
func TestAuthMiddleware_Handler(t *testing.T) {
@@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // do nothing
+
})
- claimsExtractor := jwtclaims.NewClaimsExtractor(
- jwtclaims.WithAudience(audience),
- jwtclaims.WithUserIDClaim(userIDClaim),
- )
+ mockAuth := &auth.MockManager{
+ ValidateAndParseTokenFunc: mockValidateAndParseToken,
+ EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
+ MarkPATUsedFunc: mockMarkPATUsed,
+ GetPATInfoFunc: mockGetAccountInfoFromPAT,
+ }
authMiddleware := NewAuthMiddleware(
- mockGetAccountInfoFromPAT,
- mockValidateAndParseToken,
- mockMarkPATUsed,
- mockCheckUserAccessByJWTGroups,
- claimsExtractor,
- audience,
- userIDClaim,
+ mockAuth,
+ func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ return nil
+ },
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) {
result := rec.Result()
defer result.Body.Close()
+
if result.StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode)
}
})
}
}
+
+func TestAuthMiddleware_Handler_Child(t *testing.T) {
+ tt := []struct {
+ name string
+ path string
+ authHeader string
+ expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
+ }{
+ {
+ name: "Valid PAT Token",
+ path: "/test",
+ authHeader: "Token " + PAT,
+ expectedUserAuth: &nbcontext.UserAuth{
+ AccountId: accountID,
+ UserId: userID,
+ Domain: testAccount.Domain,
+ DomainCategory: testAccount.DomainCategory,
+ IsPAT: true,
+ },
+ },
+ {
+ name: "Valid PAT Token ignores child",
+ path: "/test?account=xyz",
+ authHeader: "Token " + PAT,
+ expectedUserAuth: &nbcontext.UserAuth{
+ AccountId: accountID,
+ UserId: userID,
+ Domain: testAccount.Domain,
+ DomainCategory: testAccount.DomainCategory,
+ IsPAT: true,
+ },
+ },
+ {
+ name: "Valid JWT Token",
+ path: "/test",
+ authHeader: "Bearer " + JWT,
+ expectedUserAuth: &nbcontext.UserAuth{
+ AccountId: accountID,
+ UserId: userID,
+ Domain: testAccount.Domain,
+ DomainCategory: testAccount.DomainCategory,
+ },
+ },
+
+ {
+ name: "Valid JWT Token with child",
+ path: "/test?account=xyz",
+ authHeader: "Bearer " + JWT,
+ expectedUserAuth: &nbcontext.UserAuth{
+ AccountId: "xyz",
+ UserId: userID,
+ Domain: testAccount.Domain,
+ DomainCategory: testAccount.DomainCategory,
+ IsChild: true,
+ },
+ },
+ }
+
+ mockAuth := &auth.MockManager{
+ ValidateAndParseTokenFunc: mockValidateAndParseToken,
+ EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
+ MarkPATUsedFunc: mockMarkPATUsed,
+ GetPATInfoFunc: mockGetAccountInfoFromPAT,
+ }
+
+ authMiddleware := NewAuthMiddleware(
+ mockAuth,
+ func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ return userAuth.AccountId, userAuth.UserId, nil
+ },
+ func(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ return nil
+ },
+ )
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ userAuth, err := nbcontext.GetUserAuthFromRequest(r)
+ if tc.expectedUserAuth != nil {
+ assert.NoError(t, err)
+ assert.Equal(t, *tc.expectedUserAuth, userAuth)
+ } else {
+ assert.Error(t, err)
+ assert.Empty(t, userAuth)
+ }
+ }))
+
+ req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
+ req.Header.Set("Authorization", tc.authHeader)
+ rec := httptest.NewRecorder()
+
+ handlerToTest.ServeHTTP(rec, req)
+
+ result := rec.Result()
+ defer result.Body.Close()
+
+ if tc.expectedUserAuth != nil {
+ assert.Equal(t, 200, result.StatusCode)
+ } else {
+ assert.Equal(t, 401, result.StatusCode)
+ }
+ })
+ }
+}
diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
index 7f8eee6e7..e2c2c1d85 100644
--- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
+++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go
@@ -77,13 +77,13 @@ func BenchmarkUpdatePeer(b *testing.B) {
func BenchmarkGetOnePeer(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
- "Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70},
- "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30},
- "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50},
- "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
- "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
- "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
- "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
+ "Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
+ "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
+ "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
+ "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
+ "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
+ "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
+ "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
"Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750},
}
@@ -111,9 +111,9 @@ func BenchmarkGetOnePeer(b *testing.B) {
func BenchmarkGetAllPeers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
- "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150},
- "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30},
- "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70},
+ "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
+ "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
+ "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
"Peers - L": {MinMsPerOpLocal: 110, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
"Groups - L": {MinMsPerOpLocal: 150, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 500},
"Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400},
diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
index 0baf76328..b7deab334 100644
--- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
+++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go
@@ -48,13 +48,12 @@ func BenchmarkUpdateUser(b *testing.B) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
- recorder := httptest.NewRecorder()
-
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
+ recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@@ -97,13 +96,12 @@ func BenchmarkGetOneUser(b *testing.B) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
- recorder := httptest.NewRecorder()
-
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
+ recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@@ -118,26 +116,25 @@ func BenchmarkGetOneUser(b *testing.B) {
func BenchmarkGetAllUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
- "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
- "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
- "Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
- "Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
- "Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
- "Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
- "Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
- "Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
+ "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
+ "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
+ "Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
+ "Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
+ "Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
+ "Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
+ "Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
+ "Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
- recorder := httptest.NewRecorder()
-
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
+ recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@@ -152,26 +149,25 @@ func BenchmarkGetAllUsers(b *testing.B) {
func BenchmarkDeleteUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
- "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
- "Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
+ "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
+ "Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
- recorder := httptest.NewRecorder()
-
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
+ recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go
index a1e121fdf..1ed3355ef 100644
--- a/management/server/http/testing/testing_tools/tools.go
+++ b/management/server/http/testing/testing_tools/tools.go
@@ -3,6 +3,7 @@ package testing_tools
import (
"bytes"
"context"
+ "errors"
"fmt"
"io"
"net"
@@ -13,7 +14,11 @@ import (
"testing"
"time"
+<<<<<<< HEAD
"github.com/netbirdio/management-integrations/integrations"
+=======
+ "github.com/golang-jwt/jwt"
+>>>>>>> main
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -23,11 +28,11 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
+ "github.com/netbirdio/netbird/management/server/auth"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
- "github.com/netbirdio/netbird/management/server/http/configs"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@@ -36,6 +41,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/management/server/util"
)
const (
@@ -120,13 +126,26 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
t.Fatalf("Failed to create manager: %v", err)
}
+ // @note this is required so that PAT's validate from store, but JWT's are mocked
+ authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
+ authManagerMock := &auth.MockManager{
+ ValidateAndParseTokenFunc: mockValidateAndParseToken,
+ EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
+ MarkPATUsedFunc: authManager.MarkPATUsed,
+ GetPATInfoFunc: authManager.GetPATInfo,
+ }
+
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
+<<<<<<< HEAD
permissionsManagerMock := permissions.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManagerMock)
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock, proxyController, permissionsManagerMock, peersManager)
+=======
+ apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, &server.Config{}, validatorMock)
+>>>>>>> main
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -316,3 +335,25 @@ func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration,
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
}
}
+
+func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
+ userAuth := nbcontext.UserAuth{}
+
+ switch token {
+ case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
+ userAuth.UserId = token
+ userAuth.AccountId = "testAccountId"
+ userAuth.Domain = "test.com"
+ userAuth.DomainCategory = "private"
+ case "otherUserId":
+ userAuth.UserId = "otherUserId"
+ userAuth.AccountId = "otherAccountId"
+ userAuth.Domain = "other.com"
+ userAuth.DomainCategory = "private"
+ case "invalidToken":
+ return userAuth, nil, errors.New("invalid token")
+ }
+
+ jwtToken := jwt.New(jwt.SigningMethodHS256)
+ return userAuth, jwtToken, nil
+}
diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go
deleted file mode 100644
index 2527acbe3..000000000
--- a/management/server/jwtclaims/claims.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package jwtclaims
-
-import (
- "time"
-
- "github.com/golang-jwt/jwt"
-)
-
-// AuthorizationClaims stores authorization information from JWTs
-type AuthorizationClaims struct {
- UserId string
- AccountId string
- Domain string
- DomainCategory string
- LastLogin time.Time
- Invited bool
-
- Raw jwt.MapClaims
-}
diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go
deleted file mode 100644
index eccd7c9e7..000000000
--- a/management/server/jwtclaims/extractor_test.go
+++ /dev/null
@@ -1,227 +0,0 @@
-package jwtclaims
-
-import (
- "context"
- "net/http"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt"
- "github.com/stretchr/testify/require"
-)
-
-func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience string) *http.Request {
- t.Helper()
- const layout = "2006-01-02T15:04:05.999Z"
-
- claimMaps := jwt.MapClaims{}
- if claims.UserId != "" {
- claimMaps[UserIDClaim] = claims.UserId
- }
- if claims.AccountId != "" {
- claimMaps[audience+AccountIDSuffix] = claims.AccountId
- }
- if claims.Domain != "" {
- claimMaps[audience+DomainIDSuffix] = claims.Domain
- }
- if claims.DomainCategory != "" {
- claimMaps[audience+DomainCategorySuffix] = claims.DomainCategory
- }
- if claims.LastLogin != (time.Time{}) {
- claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
- }
-
- if claims.Invited {
- claimMaps[audience+Invited] = true
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
- r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
- require.NoError(t, err, "creating testing request failed")
- testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
-
- return testRequest
-}
-
-func TestExtractClaimsFromRequestContext(t *testing.T) {
- type test struct {
- name string
- inputAuthorizationClaims AuthorizationClaims
- inputAudiance string
- testingFunc require.ComparisonAssertionFunc
- expectedMSG string
- }
-
- const layout = "2006-01-02T15:04:05.999Z"
- lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z")
-
- testCase1 := test{
- name: "All Claim Fields",
- inputAudiance: "https://login/",
- inputAuthorizationClaims: AuthorizationClaims{
- UserId: "test",
- Domain: "test.com",
- AccountId: "testAcc",
- LastLogin: lastLogin,
- DomainCategory: "public",
- Invited: true,
- Raw: jwt.MapClaims{
- "https://login/wt_account_domain": "test.com",
- "https://login/wt_account_domain_category": "public",
- "https://login/wt_account_id": "testAcc",
- "https://login/nb_last_login": lastLogin.Format(layout),
- "sub": "test",
- "https://login/" + Invited: true,
- },
- },
- testingFunc: require.EqualValues,
- expectedMSG: "extracted claims should match input claims",
- }
-
- testCase2 := test{
- name: "Domain Is Empty",
- inputAudiance: "https://login/",
- inputAuthorizationClaims: AuthorizationClaims{
- UserId: "test",
- AccountId: "testAcc",
- Raw: jwt.MapClaims{
- "https://login/wt_account_id": "testAcc",
- "sub": "test",
- },
- },
- testingFunc: require.EqualValues,
- expectedMSG: "extracted claims should match input claims",
- }
-
- testCase3 := test{
- name: "Account ID Is Empty",
- inputAudiance: "https://login/",
- inputAuthorizationClaims: AuthorizationClaims{
- UserId: "test",
- Domain: "test.com",
- Raw: jwt.MapClaims{
- "https://login/wt_account_domain": "test.com",
- "sub": "test",
- },
- },
- testingFunc: require.EqualValues,
- expectedMSG: "extracted claims should match input claims",
- }
-
- testCase4 := test{
- name: "Category Is Empty",
- inputAudiance: "https://login/",
- inputAuthorizationClaims: AuthorizationClaims{
- UserId: "test",
- Domain: "test.com",
- AccountId: "testAcc",
- Raw: jwt.MapClaims{
- "https://login/wt_account_domain": "test.com",
- "https://login/wt_account_id": "testAcc",
- "sub": "test",
- },
- },
- testingFunc: require.EqualValues,
- expectedMSG: "extracted claims should match input claims",
- }
-
- testCase5 := test{
- name: "Only User ID Is set",
- inputAudiance: "https://login/",
- inputAuthorizationClaims: AuthorizationClaims{
- UserId: "test",
- Raw: jwt.MapClaims{
- "sub": "test",
- },
- },
- testingFunc: require.EqualValues,
- expectedMSG: "extracted claims should match input claims",
- }
-
- for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
- t.Run(testCase.name, func(t *testing.T) {
- request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
-
- extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
- extractedClaims := extractor.FromRequestContext(request)
-
- testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
- })
- }
-}
-
-func TestExtractClaimsSetOptions(t *testing.T) {
- t.Helper()
- type test struct {
- name string
- extractor *ClaimsExtractor
- check func(t *testing.T, c test)
- }
-
- testCase1 := test{
- name: "No custom options",
- extractor: NewClaimsExtractor(),
- check: func(t *testing.T, c test) {
- t.Helper()
- if c.extractor.authAudience != "" {
- t.Error("audience should be empty")
- return
- }
- if c.extractor.userIDClaim != UserIDClaim {
- t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
- return
- }
- if c.extractor.FromRequestContext == nil {
- t.Error("from request context should not be nil")
- return
- }
- },
- }
-
- testCase2 := test{
- name: "Custom audience",
- extractor: NewClaimsExtractor(WithAudience("https://login/")),
- check: func(t *testing.T, c test) {
- t.Helper()
- if c.extractor.authAudience != "https://login/" {
- t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
- return
- }
- },
- }
-
- testCase3 := test{
- name: "Custom user id claim",
- extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
- check: func(t *testing.T, c test) {
- t.Helper()
- if c.extractor.userIDClaim != "customUserId" {
- t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
- return
- }
- },
- }
-
- testCase4 := test{
- name: "Custom extractor from request context",
- extractor: NewClaimsExtractor(
- WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
- return AuthorizationClaims{
- UserId: "testCustomRequest",
- }
- })),
- check: func(t *testing.T, c test) {
- t.Helper()
- claims := c.extractor.FromRequestContext(&http.Request{})
- if claims.UserId != "testCustomRequest" {
- t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
- return
- }
- },
- }
-
- for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
- t.Run(testCase.name, func(t *testing.T) {
- testCase.check(t, testCase)
- })
- }
-}
diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go
deleted file mode 100644
index 79e59e76f..000000000
--- a/management/server/jwtclaims/jwtValidator.go
+++ /dev/null
@@ -1,349 +0,0 @@
-package jwtclaims
-
-import (
- "context"
- "crypto/ecdsa"
- "crypto/elliptic"
- "crypto/rsa"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "math/big"
- "net/http"
- "strconv"
- "strings"
- "sync"
- "time"
-
- "github.com/golang-jwt/jwt"
- log "github.com/sirupsen/logrus"
-)
-
-// Options is a struct for specifying configuration options for the middleware.
-type Options struct {
- // The function that will return the Key to validate the JWT.
- // It can be either a shared secret or a public key.
- // Default value: nil
- ValidationKeyGetter jwt.Keyfunc
- // The name of the property in the request where the user information
- // from the JWT will be stored.
- // Default value: "user"
- UserProperty string
- // The function that will be called when there's an error validating the token
- // Default value:
- CredentialsOptional bool
- // A function that extracts the token from the request
- // Default: FromAuthHeader (i.e., from Authorization header as bearer token)
- Debug bool
- // When set, all requests with the OPTIONS method will use authentication
- // Default: false
- EnableAuthOnOptions bool
-}
-
-// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
-type Jwks struct {
- Keys []JSONWebKey `json:"keys"`
- expiresInTime time.Time
-}
-
-// The supported elliptic curves types
-const (
- // p256 represents a cryptographic elliptical curve type.
- p256 = "P-256"
-
- // p384 represents a cryptographic elliptical curve type.
- p384 = "P-384"
-
- // p521 represents a cryptographic elliptical curve type.
- p521 = "P-521"
-)
-
-// JSONWebKey is a representation of a Jason Web Key
-type JSONWebKey struct {
- Kty string `json:"kty"`
- Kid string `json:"kid"`
- Use string `json:"use"`
- N string `json:"n"`
- E string `json:"e"`
- Crv string `json:"crv"`
- X string `json:"x"`
- Y string `json:"y"`
- X5c []string `json:"x5c"`
-}
-
-type JWTValidator interface {
- ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
-}
-
-// jwtValidatorImpl struct to handle token validation and parsing
-type jwtValidatorImpl struct {
- options Options
-}
-
-var keyNotFound = errors.New("unable to find appropriate key")
-
-// NewJWTValidator constructor
-func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
- keys, err := getPemKeys(ctx, keysLocation)
- if err != nil {
- return nil, err
- }
-
- var lock sync.Mutex
- options := Options{
- ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
- // Verify 'aud' claim
- var checkAud bool
- for _, audience := range audienceList {
- checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
- if checkAud {
- break
- }
- }
- if !checkAud {
- return token, errors.New("invalid audience")
- }
- // Verify 'issuer' claim
- checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(issuer, false)
- if !checkIss {
- return token, errors.New("invalid issuer")
- }
-
- // If keys are rotated, verify the keys prior to token validation
- if idpSignkeyRefreshEnabled {
- // If the keys are invalid, retrieve new ones
- if !keys.stillValid() {
- lock.Lock()
- defer lock.Unlock()
-
- refreshedKeys, err := getPemKeys(ctx, keysLocation)
- if err != nil {
- log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
- refreshedKeys = keys
- }
-
- log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
-
- keys = refreshedKeys
- }
- }
-
- publicKey, err := getPublicKey(ctx, token, keys)
- if err == nil {
- return publicKey, nil
- }
-
- msg := fmt.Sprintf("getPublicKey error: %s", err)
- if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
- msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
- }
-
- log.WithContext(ctx).Error(msg)
-
- return nil, err
- },
- EnableAuthOnOptions: false,
- }
-
- if options.UserProperty == "" {
- options.UserProperty = "user"
- }
-
- return &jwtValidatorImpl{
- options: options,
- }, nil
-}
-
-// ValidateAndParse validates the token and returns the parsed token
-func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
- // If the token is empty...
- if token == "" {
- // Check if it was required
- if m.options.CredentialsOptional {
- log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
- // No error, just no token (and that is ok given that CredentialsOptional is true)
- return nil, nil //nolint:nilnil
- }
-
- // If we get here, the required token is missing
- errorMsg := "required authorization token not found"
- log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
- return nil, errors.New(errorMsg)
- }
-
- // Now parse the token
- parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
-
- // Check if there was an error in parsing...
- if err != nil {
- log.WithContext(ctx).Errorf("error parsing token: %v", err)
- return nil, fmt.Errorf("error parsing token: %w", err)
- }
-
- // Check if the parsed token is valid...
- if !parsedToken.Valid {
- errorMsg := "token is invalid"
- log.WithContext(ctx).Debug(errorMsg)
- return nil, errors.New(errorMsg)
- }
-
- return parsedToken, nil
-}
-
-// stillValid returns true if the JSONWebKey still valid and have enough time to be used
-func (jwks *Jwks) stillValid() bool {
- return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
-}
-
-func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
- resp, err := http.Get(keysLocation)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
-
- jwks := &Jwks{}
- err = json.NewDecoder(resp.Body).Decode(jwks)
- if err != nil {
- return jwks, err
- }
-
- cacheControlHeader := resp.Header.Get("Cache-Control")
- expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
- jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
-
- return jwks, err
-}
-
-func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
- // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
-
- for k := range jwks.Keys {
- if token.Header["kid"] != jwks.Keys[k].Kid {
- continue
- }
-
- if len(jwks.Keys[k].X5c) != 0 {
- cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
- return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
- }
-
- if jwks.Keys[k].Kty == "RSA" {
- log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
- return getPublicKeyFromRSA(jwks.Keys[k])
- }
- if jwks.Keys[k].Kty == "EC" {
- log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
- return getPublicKeyFromECDSA(jwks.Keys[k])
- }
-
- log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
- }
-
- return nil, keyNotFound
-}
-
-func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
-
- if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
- return nil, fmt.Errorf("ecdsa key incomplete")
- }
-
- var xCoordinate []byte
- if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
- return nil, err
- }
-
- var yCoordinate []byte
- if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
- return nil, err
- }
-
- publicKey = &ecdsa.PublicKey{}
-
- var curve elliptic.Curve
- switch jwk.Crv {
- case p256:
- curve = elliptic.P256()
- case p384:
- curve = elliptic.P384()
- case p521:
- curve = elliptic.P521()
- }
-
- publicKey.Curve = curve
- publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
- publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
-
- return publicKey, nil
-}
-
-func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
-
- decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
- if err != nil {
- return nil, err
- }
- decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
- if err != nil {
- return nil, err
- }
-
- var n, e big.Int
- e.SetBytes(decodedE)
- n.SetBytes(decodedN)
-
- return &rsa.PublicKey{
- E: int(e.Int64()),
- N: &n,
- }, nil
-}
-
-// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
-func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
- // Split into individual directives
- directives := strings.Split(cacheControl, ",")
-
- for _, directive := range directives {
- directive = strings.TrimSpace(directive)
- if strings.HasPrefix(directive, "max-age=") {
- // Extract the max-age value
- maxAgeStr := strings.TrimPrefix(directive, "max-age=")
- maxAge, err := strconv.Atoi(maxAgeStr)
- if err != nil {
- log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
- return 0
- }
-
- return maxAge
- }
- }
-
- return 0
-}
-
-type JwtValidatorMock struct{}
-
-func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
- claimMaps := jwt.MapClaims{}
-
- switch token {
- case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
- claimMaps[UserIDClaim] = token
- claimMaps[AccountIDSuffix] = "testAccountId"
- claimMaps[DomainIDSuffix] = "test.com"
- claimMaps[DomainCategorySuffix] = "private"
- case "otherUserId":
- claimMaps[UserIDClaim] = "otherUserId"
- claimMaps[AccountIDSuffix] = "otherAccountId"
- claimMaps[DomainIDSuffix] = "other.com"
- claimMaps[DomainCategorySuffix] = "private"
- case "invalidToken":
- return nil, errors.New("invalid token")
- }
-
- jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
- return jwtToken, nil
-}
-
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index 1c9703c58..c838c4a27 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -441,7 +441,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
ephemeralMgr := NewEphemeralManager(store, accountManager)
- mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr)
+ mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr, nil)
if err != nil {
return nil, nil, "", cleanup, err
}
diff --git a/management/server/management_test.go b/management/server/management_test.go
index 56c295815..838065e49 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -206,6 +206,7 @@ func startServer(
secretsManager,
nil,
nil,
+ nil,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index d5bee97f3..5564aab01 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -13,14 +13,16 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
+var _ server.AccountManager = (*MockAccountManager)(nil)
+
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
@@ -29,7 +31,7 @@ type MockAccountManager struct {
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
- GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
+ GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
@@ -54,8 +56,6 @@ type MockAccountManager struct {
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
- GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
- MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
@@ -80,8 +80,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
- GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
- CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
+ GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
@@ -240,14 +239,6 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
-// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
-func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
- if am.GetPATInfoFunc != nil {
- return am.GetPATInfoFunc(ctx, pat)
- }
- return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
-}
-
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {
@@ -256,14 +247,6 @@ func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, user
return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented")
}
-// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
-func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error {
- if am.MarkPATUsedFunc != nil {
- return am.MarkPATUsedFunc(ctx, pat)
- }
- return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
-}
-
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
@@ -430,11 +413,11 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
}
// GetUser mock implementation of GetUser from server.AccountManager interface
-func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
- if am.GetUserFunc != nil {
- return am.GetUserFunc(ctx, claims)
+func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
+ if am.GetUserFromUserAuthFunc != nil {
+ return am.GetUserFromUserAuthFunc(ctx, userAuth)
}
- return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
+ return nil, status.Errorf(codes.Unimplemented, "method GetUserFromUserAuth is not implemented")
}
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
@@ -614,19 +597,11 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
-// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
-func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
- if am.GetAccountIDFromTokenFunc != nil {
- return am.GetAccountIDFromTokenFunc(ctx, claims)
+func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
+ if am.GetAccountIDFromUserAuthFunc != nil {
+ return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
}
- return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
-}
-
-func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
- if am.CheckUserAccessByJWTGroupsFunc != nil {
- return am.CheckUserAccessByJWTGroupsFunc(ctx, claims)
- }
- return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
+ return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromUserAuth is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
@@ -859,3 +834,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
}
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}
+
+func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
+ return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
+}
diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go
index 6d86557df..1dae3999b 100644
--- a/management/server/store/sql_store.go
+++ b/management/server/store/sql_store.go
@@ -15,7 +15,6 @@ import (
"sync"
"time"
- "github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@@ -24,6 +23,8 @@ import (
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
+ "github.com/netbirdio/netbird/management/server/util"
+
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -615,6 +616,16 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return groups, nil
}
+func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
+ var count int64
+ result := s.db.Model(&types.Account{}).Count(&count)
+ if result.Error != nil {
+ return 0, fmt.Errorf("failed to get all accounts counter: %w", result.Error)
+ }
+
+ return count, nil
+}
+
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account
result := s.db.Find(&accounts)
@@ -1035,6 +1046,13 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
}
for _, account := range fileStore.GetAllAccounts(ctx) {
+ _, err = account.GetGroupAll()
+ if err != nil {
+ if err := account.AddAllGroup(); err != nil {
+ return nil, err
+ }
+ }
+
err := store.SaveAccount(ctx, account)
if err != nil {
return nil, err
diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go
index bdb5905bd..54649c5c1 100644
--- a/management/server/store/sql_store_test.go
+++ b/management/server/store/sql_store_test.go
@@ -15,7 +15,6 @@ import (
"time"
"github.com/google/uuid"
- "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -2045,52 +2044,12 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
},
}
- if err := addAllGroup(acc); err != nil {
+ if err := acc.AddAllGroup(); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
}
-// addAllGroup to account object if it doesn't exist
-func addAllGroup(account *types.Account) error {
- if len(account.Groups) == 0 {
- allGroup := &types.Group{
- ID: xid.New().String(),
- Name: "All",
- Issued: types.GroupIssuedAPI,
- }
- for _, peer := range account.Peers {
- allGroup.Peers = append(allGroup.Peers, peer.ID)
- }
- account.Groups = map[string]*types.Group{allGroup.ID: allGroup}
-
- id := xid.New().String()
-
- defaultPolicy := &types.Policy{
- ID: id,
- Name: types.DefaultRuleName,
- Description: types.DefaultRuleDescription,
- Enabled: true,
- Rules: []*types.PolicyRule{
- {
- ID: id,
- Name: types.DefaultRuleName,
- Description: types.DefaultRuleDescription,
- Enabled: true,
- Sources: []string{allGroup.ID},
- Destinations: []string{allGroup.ID},
- Bidirectional: true,
- Protocol: types.PolicyRuleProtocolALL,
- Action: types.PolicyTrafficActionAccept,
- },
- },
- }
-
- account.Policies = []*types.Policy{defaultPolicy}
- }
- return nil
-}
-
func TestSqlStore_GetAccountNetworks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
diff --git a/management/server/store/store.go b/management/server/store/store.go
index e94ba2f35..d84d699bb 100644
--- a/management/server/store/store.go
+++ b/management/server/store/store.go
@@ -48,6 +48,7 @@ const (
)
type Store interface {
+ GetAccountsCounter(ctx context.Context) (int64, error)
GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
@@ -352,7 +353,46 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
}
- return getSqlStoreEngine(ctx, store, kind)
+ err = addAllGroupToAccount(ctx, store)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
+ }
+
+
+ maxRetries := 2
+ for i := 0; i < maxRetries; i++ {
+ sqlStore, cleanUp, err := getSqlStoreEngine(ctx, store, kind)
+ if err == nil {
+ return sqlStore, cleanUp, nil
+ }
+ if i < maxRetries-1 {
+ time.Sleep(100 * time.Millisecond)
+ }
+ }
+ return nil, nil, fmt.Errorf("failed to create test store after %d attempts: %v", maxRetries, err)
+}
+
+func addAllGroupToAccount(ctx context.Context, store Store) error {
+ allAccounts := store.GetAllAccounts(ctx)
+ for _, account := range allAccounts {
+ shouldSave := false
+
+ _, err := account.GetGroupAll()
+ if err != nil {
+ if err := account.AddAllGroup(); err != nil {
+ return err
+ }
+ shouldSave = true
+ }
+
+ if shouldSave {
+ err = store.SaveAccount(ctx, account)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
}
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql
index cda333d4f..8b09ec2be 100644
--- a/management/server/testdata/storev1.sql
+++ b/management/server/testdata/storev1.sql
@@ -36,4 +36,3 @@ INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|6
INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',1,'""','','',0);
INSERT INTO installations VALUES(1,'');
-
diff --git a/management/server/types/account.go b/management/server/types/account.go
index 4c68b9523..c890a7730 100644
--- a/management/server/types/account.go
+++ b/management/server/types/account.go
@@ -12,6 +12,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
+ "github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@@ -1525,3 +1526,43 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
return sourcePeers
}
+
+// AddAllGroup to account object if it doesn't exist
+func (a *Account) AddAllGroup() error {
+ if len(a.Groups) == 0 {
+ allGroup := &Group{
+ ID: xid.New().String(),
+ Name: "All",
+ Issued: GroupIssuedAPI,
+ }
+ for _, peer := range a.Peers {
+ allGroup.Peers = append(allGroup.Peers, peer.ID)
+ }
+ a.Groups = map[string]*Group{allGroup.ID: allGroup}
+
+ id := xid.New().String()
+
+ defaultPolicy := &Policy{
+ ID: id,
+ Name: DefaultRuleName,
+ Description: DefaultRuleDescription,
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ ID: id,
+ Name: DefaultRuleName,
+ Description: DefaultRuleDescription,
+ Enabled: true,
+ Sources: []string{allGroup.ID},
+ Destinations: []string{allGroup.ID},
+ Bidirectional: true,
+ Protocol: PolicyRuleProtocolALL,
+ Action: PolicyTrafficActionAccept,
+ },
+ },
+ }
+
+ a.Policies = []*Policy{defaultPolicy}
+ }
+ return nil
+}
diff --git a/management/server/user.go b/management/server/user.go
index 6ba9b68d3..381879ae6 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -8,16 +8,16 @@ import (
"time"
"github.com/google/uuid"
+ log "github.com/sirupsen/logrus"
+
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
- log "github.com/sirupsen/logrus"
)
// createServiceUser creates a new service user under the given account.
@@ -174,31 +174,26 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
}
-// GetUser looks up a user by provided authorization claims.
-// It will also create an account if didn't exist for this user before.
-func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
- accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
- if err != nil {
- return nil, fmt.Errorf("failed to get account with token claims %v", err)
- }
-
- user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
+// GetUser looks up a user by provided nbContext.UserAuths.
+// Expects account to have been created already.
+func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
+ user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
return nil, err
}
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
- newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
+ newLogin := user.LastDashboardLoginChanged(userAuth.LastLogin)
- err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
+ err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}
if newLogin {
- meta := map[string]any{"timestamp": claims.LastLogin}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
+ meta := map[string]any{"timestamp": userAuth.LastLogin}
+ am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, userAuth.AccountId, activity.DashboardLogin, meta)
}
return user, nil
diff --git a/management/server/user_test.go b/management/server/user_test.go
index 4a532c8a6..a180a761a 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -10,6 +10,8 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
+
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"golang.org/x/exp/maps"
@@ -25,7 +27,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
- "github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
@@ -925,11 +926,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
- claims := jwtclaims.AuthorizationClaims{
- UserId: mockUserID,
+ claims := nbcontext.UserAuth{
+ UserId: mockUserID,
+ AccountId: mockAccountID,
}
- user, err := am.GetUser(context.Background(), claims)
+ user, err := am.GetUserFromUserAuth(context.Background(), claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
diff --git a/release_files/ui-post-install.sh b/release_files/ui-post-install.sh
new file mode 100644
index 000000000..f6e8ddf92
--- /dev/null
+++ b/release_files/ui-post-install.sh
@@ -0,0 +1,10 @@
+#!/bin/sh
+
+# Check if netbird-ui is running
+if pgrep -x -f /usr/bin/netbird-ui >/dev/null 2>&1;
+then
+ runner=$(ps --no-headers -o '%U' -p $(pgrep -x -f /usr/bin/netbird-ui) | sed 's/^[ \t]*//;s/[ \t]*$//')
+ # Only re-run if it was already running
+ pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1
+ su -l - "$runner" -c 'nohup /usr/bin/netbird-ui > /dev/null 2>&1 &'
+fi